Skip to main content

http_type/websocket_frame/
impl.rs

1use crate::*;
2
3/// Implements the `Default` trait for `WebSocketFrame`.
4///
5/// Provides a default `WebSocketFrame` with `fin: false`, `opcode: WebSocketOpcode::Text`,
6/// `mask: false`, and an empty `payload_data`.
7impl Default for WebSocketFrame {
8    /// Returns the default `WebSocketFrame`.
9    ///
10    /// # Returns
11    ///
12    /// - `Self` - A default `WebSocketFrame` instance.
13    #[inline(always)]
14    fn default() -> Self {
15        Self {
16            fin: false,
17            opcode: WebSocketOpcode::Text,
18            mask: false,
19            payload_data: Vec::new(),
20        }
21    }
22}
23
24impl WebSocketOpcode {
25    /// Creates a `WebSocketOpcode` from a raw u8 value.
26    ///
27    /// # Arguments
28    ///
29    /// - `u8`: The raw opcode value.
30    ///
31    /// # Returns
32    ///
33    /// - `Self` - A `WebSocketOpcode` enum variant corresponding to the raw value.
34    #[inline(always)]
35    pub fn from_u8(opcode: u8) -> Self {
36        match opcode {
37            0x0 => Self::Continuation,
38            0x1 => Self::Text,
39            0x2 => Self::Binary,
40            0x8 => Self::Close,
41            0x9 => Self::Ping,
42            0xA => Self::Pong,
43            _ => Self::Reserved(opcode),
44        }
45    }
46
47    /// Converts the `WebSocketOpcode` to its raw u8 value.
48    ///
49    /// # Returns
50    ///
51    /// - `u8` - The raw u8 value of the opcode.
52    #[inline(always)]
53    pub fn to_u8(&self) -> u8 {
54        match self {
55            Self::Continuation => 0x0,
56            Self::Text => 0x1,
57            Self::Binary => 0x2,
58            Self::Close => 0x8,
59            Self::Ping => 0x9,
60            Self::Pong => 0xA,
61            Self::Reserved(code) => *code,
62        }
63    }
64
65    /// Checks if the opcode is a control frame.
66    ///
67    /// # Returns
68    ///
69    /// - `bool` - `true` if the opcode represents a control frame (Close, Ping, Pong), otherwise `false`.
70    #[inline(always)]
71    pub fn is_control(&self) -> bool {
72        matches!(self, Self::Close | Self::Ping | Self::Pong)
73    }
74
75    /// Checks if the opcode is a data frame.
76    ///
77    /// # Returns
78    ///
79    /// - `bool` - `true` if the opcode represents a data frame (Text, Binary, Continuation), otherwise `false`.
80    #[inline(always)]
81    pub fn is_data(&self) -> bool {
82        matches!(self, Self::Text | Self::Binary | Self::Continuation)
83    }
84
85    /// Checks if the opcode is a continuation frame.
86    ///
87    /// # Returns
88    ///
89    /// - `bool` - `true` if the opcode is `Continuation`, otherwise `false`.
90    #[inline(always)]
91    pub fn is_continuation(&self) -> bool {
92        matches!(self, Self::Continuation)
93    }
94
95    /// Checks if the opcode is a text frame.
96    ///
97    /// # Returns
98    ///
99    /// - `bool` - `true` if the opcode is `Text`, otherwise `false`.
100    #[inline(always)]
101    pub fn is_text(&self) -> bool {
102        matches!(self, Self::Text)
103    }
104
105    /// Checks if the opcode is a binary frame.
106    ///
107    /// # Returns
108    ///
109    /// - `bool` - `true` if the opcode is `Binary`, otherwise `false`.
110    #[inline(always)]
111    pub fn is_binary(&self) -> bool {
112        matches!(self, Self::Binary)
113    }
114
115    /// Checks if the opcode is a close frame.
116    ///
117    /// # Returns
118    ///
119    /// - `bool` - `true` if the opcode is `Close`, otherwise `false`.
120    #[inline(always)]
121    pub fn is_close(&self) -> bool {
122        matches!(self, Self::Close)
123    }
124
125    /// Checks if the opcode is a ping frame.
126    ///
127    /// # Returns
128    ///
129    /// - `bool` - `true` if the opcode is `Ping`, otherwise `false`.
130    #[inline(always)]
131    pub fn is_ping(&self) -> bool {
132        matches!(self, Self::Ping)
133    }
134
135    /// Checks if the opcode is a pong frame.
136    ///
137    /// # Returns
138    ///
139    /// - `bool` - `true` if the opcode is `Pong`, otherwise `false`.
140    #[inline(always)]
141    pub fn is_pong(&self) -> bool {
142        matches!(self, Self::Pong)
143    }
144
145    /// Checks if the opcode is a reserved frame.
146    ///
147    /// # Returns
148    ///
149    /// - `bool` - `true` if the opcode is `Reserved(_)`, otherwise `false`.
150    #[inline(always)]
151    pub fn is_reserved(&self) -> bool {
152        matches!(self, Self::Reserved(_))
153    }
154}
155
156impl WebSocketFrame {
157    /// Decodes a WebSocket frame from the provided data slice.
158    ///
159    /// This function parses the raw bytes from a WebSocket stream according to the WebSocket protocol
160    /// specification to reconstruct a `WebSocketFrame`. It handles FIN bit, opcode, mask bit,
161    /// payload length (including extended lengths), mask key, and the payload data itself.
162    ///
163    /// # Arguments
164    ///
165    /// - `AsRef<[u8]>` - The raw data to decode into a WebSocket frame.
166    ///
167    /// # Returns
168    ///
169    /// - `Option<(WebSocketFrame, usize)>`
170    ///     - `Some((WebSocketFrame, usize))`: If the frame is successfully decoded, returns the decoded frame
171    ///       and the number of bytes consumed from the input slice.
172    ///     - `None`: If the frame is incomplete or malformed.
173    pub fn decode_ws_frame<D>(data: D) -> Option<(WebSocketFrame, usize)>
174    where
175        D: AsRef<[u8]>,
176    {
177        let data_ref: &[u8] = data.as_ref();
178        if data_ref.len() < 2 {
179            return None;
180        }
181        let mut index: usize = 0;
182        let fin: bool = (data_ref[index] & 0b1000_0000) != 0;
183        let opcode: WebSocketOpcode = WebSocketOpcode::from_u8(data_ref[index] & 0b0000_1111);
184        index += 1;
185        let mask: bool = (data_ref[index] & 0b1000_0000) != 0;
186        let mut payload_len: usize = (data_ref[index] & 0b0111_1111) as usize;
187        index += 1;
188        if payload_len == 126 {
189            if data_ref.len() < index + 2 {
190                return None;
191            }
192            payload_len = u16::from_be_bytes(data_ref[index..index + 2].try_into().ok()?) as usize;
193            index += 2;
194        } else if payload_len == 127 {
195            if data_ref.len() < index + 8 {
196                return None;
197            }
198            payload_len = u64::from_be_bytes(data_ref[index..index + 8].try_into().ok()?) as usize;
199            index += 8;
200        }
201        let mask_key: Option<[u8; 4]> = if mask {
202            if data_ref.len() < index + 4 {
203                return None;
204            }
205            let key: [u8; 4] = data_ref[index..index + 4].try_into().ok()?;
206            index += 4;
207            Some(key)
208        } else {
209            None
210        };
211        if data_ref.len() < index + payload_len {
212            return None;
213        }
214        let mut payload: Vec<u8> = data_ref[index..index + payload_len].to_vec();
215        if let Some(mask_key) = mask_key {
216            for (byte_index, payload_byte) in payload.iter_mut().enumerate() {
217                *payload_byte ^= mask_key[byte_index % 4];
218            }
219        }
220        index += payload_len;
221        let frame: WebSocketFrame = WebSocketFrame {
222            fin,
223            opcode,
224            mask,
225            payload_data: payload,
226        };
227        Some((frame, index))
228    }
229
230    /// Creates a list of response frames from the provided body.
231    ///
232    /// This method segments the response body into WebSocket frames, respecting the maximum frame size
233    /// and handling UTF-8 character boundaries for text frames. It determines the appropriate opcode
234    /// (Text or Binary) based on the body's content.
235    ///
236    /// # Arguments
237    ///
238    /// - `AsRef<[u8]>` - A reference to a response body (payload) as a byte slice.
239    ///
240    /// # Returns
241    ///
242    /// - `Vec<ResponseBody>` - A vector of `ResponseBody` (byte vectors), where each element represents a framed WebSocket message.
243    pub fn create_frame_list<D>(data: D) -> Vec<ResponseBody>
244    where
245        D: AsRef<[u8]>,
246    {
247        let data_ref: &[u8] = data.as_ref();
248        let total_len: usize = data_ref.len();
249        let mut offset: usize = 0;
250        let mut frames_list: Vec<ResponseBody> =
251            Vec::with_capacity((total_len / MAX_FRAME_SIZE) + 1);
252        let mut is_first_frame: bool = true;
253        let is_valid_utf8: bool = std::str::from_utf8(data_ref).is_ok();
254        let base_opcode: WebSocketOpcode = if is_valid_utf8 {
255            WebSocketOpcode::Text
256        } else {
257            WebSocketOpcode::Binary
258        };
259        while offset < total_len {
260            let remaining: usize = total_len - offset;
261            let mut frame_size: usize = remaining.min(MAX_FRAME_SIZE);
262            if is_valid_utf8 && frame_size < remaining {
263                while frame_size > 0 && (data_ref[offset + frame_size] & 0xC0) == 0x80 {
264                    frame_size -= 1;
265                }
266                if frame_size == 0 {
267                    frame_size = remaining.min(MAX_FRAME_SIZE);
268                }
269            }
270            let mut frame: ResponseBody = Vec::with_capacity(frame_size + 10);
271            let opcode: WebSocketOpcode = if is_first_frame {
272                base_opcode
273            } else {
274                WebSocketOpcode::Continuation
275            };
276            let fin: u8 = if remaining > frame_size { 0x00 } else { 0x80 };
277            let opcode_byte: u8 = opcode.to_u8() & 0x0F;
278            frame.push(fin | opcode_byte);
279            if frame_size < 126 {
280                frame.push(frame_size as u8);
281            } else if frame_size <= MAX_FRAME_SIZE {
282                frame.push(126);
283                frame.extend_from_slice(&(frame_size as u16).to_be_bytes());
284            } else {
285                frame.push(127);
286                frame.extend_from_slice(&(frame_size as u16).to_be_bytes());
287            }
288            let end: usize = offset + frame_size;
289            frame.extend_from_slice(&data_ref[offset..end]);
290            frames_list.push(frame);
291            offset = end;
292            is_first_frame = false;
293        }
294        frames_list
295    }
296
297    /// Calculates the SHA-1 hash of the input data.
298    ///
299    /// This function implements the SHA-1 cryptographic hash algorithm according to RFC 3174.
300    /// It processes the input data in 512-bit (64-byte) blocks and produces a 160-bit (20-byte) hash.
301    ///
302    /// # Arguments
303    ///
304    /// - `AsRef<[u8]>` - The input data to be hashed.
305    ///
306    /// # Returns
307    ///
308    /// - `[u8; 20]` - A 20-byte array representing the SHA-1 hash of the input data.
309    pub fn sha1<D>(data: D) -> [u8; 20]
310    where
311        D: AsRef<[u8]>,
312    {
313        let data_ref: &[u8] = data.as_ref();
314        let mut hash_state: [u32; 5] = HASH_STATE;
315        let mut padded_data: Vec<u8> = Vec::from(data_ref);
316        let original_size_bits: u64 = (padded_data.len() * 8) as u64;
317        padded_data.push(0x80);
318        while !(padded_data.len() + 8).is_multiple_of(64) {
319            padded_data.push(0);
320        }
321        padded_data.extend_from_slice(&original_size_bits.to_be_bytes());
322        for block in padded_data.chunks_exact(64) {
323            let mut message_schedule: [u32; 80] = [0u32; 80];
324            for (chunk_index, block_chunk) in block.chunks_exact(4).enumerate().take(16) {
325                message_schedule[chunk_index] = u32::from_be_bytes([
326                    block_chunk[0],
327                    block_chunk[1],
328                    block_chunk[2],
329                    block_chunk[3],
330                ]);
331            }
332            for schedule_index in 16..80 {
333                message_schedule[schedule_index] = (message_schedule[schedule_index - 3]
334                    ^ message_schedule[schedule_index - 8]
335                    ^ message_schedule[schedule_index - 14]
336                    ^ message_schedule[schedule_index - 16])
337                    .rotate_left(1);
338            }
339            let [mut hash_a, mut hash_b, mut hash_c, mut hash_d, mut hash_e] = hash_state;
340            for (round_index, &schedule_word) in message_schedule.iter().enumerate() {
341                let (round_function, round_constant): (u32, u32) = match round_index {
342                    0..=19 => ((hash_b & hash_c) | (!hash_b & hash_d), 0x5A827999),
343                    20..=39 => (hash_b ^ hash_c ^ hash_d, 0x6ED9EBA1),
344                    40..=59 => (
345                        (hash_b & hash_c) | (hash_b & hash_d) | (hash_c & hash_d),
346                        0x8F1BBCDC,
347                    ),
348                    _ => (hash_b ^ hash_c ^ hash_d, 0xCA62C1D6),
349                };
350                let temp: u32 = hash_a
351                    .rotate_left(5)
352                    .wrapping_add(round_function)
353                    .wrapping_add(hash_e)
354                    .wrapping_add(round_constant)
355                    .wrapping_add(schedule_word);
356                hash_e = hash_d;
357                hash_d = hash_c;
358                hash_c = hash_b.rotate_left(30);
359                hash_b = hash_a;
360                hash_a = temp;
361            }
362            hash_state[0] = hash_state[0].wrapping_add(hash_a);
363            hash_state[1] = hash_state[1].wrapping_add(hash_b);
364            hash_state[2] = hash_state[2].wrapping_add(hash_c);
365            hash_state[3] = hash_state[3].wrapping_add(hash_d);
366            hash_state[4] = hash_state[4].wrapping_add(hash_e);
367        }
368        let mut result: [u8; 20] = [0u8; 20];
369        for (state_index, &state_value) in hash_state.iter().enumerate() {
370            result[state_index * 4..(state_index + 1) * 4]
371                .copy_from_slice(&state_value.to_be_bytes());
372        }
373        result
374    }
375
376    /// Generates a WebSocket accept key from the client-provided key, returning an `Option<String>`.
377    ///
378    /// # Arguments
379    ///
380    /// - `AsRef<str>` - The client-provided key (typically from the `Sec-WebSocket-Key` header).
381    ///
382    /// # Returns
383    ///
384    /// - `Option<String>` - An optional string representing the generated WebSocket accept key (typically for the `Sec-WebSocket-Accept` header).
385    #[inline(always)]
386    pub fn try_generate_accept_key<K>(key: K) -> Option<String>
387    where
388        K: AsRef<str>,
389    {
390        let key_ref: &str = key.as_ref();
391        let mut data: [u8; 60] = [0u8; 60];
392        data[..24].copy_from_slice(&key_ref.as_bytes()[..24.min(key_ref.len())]);
393        data[24..].copy_from_slice(GUID);
394        let hash: [u8; 20] = Self::sha1(data);
395        Self::try_base64_encode(hash)
396    }
397
398    /// Generates a WebSocket accept key from the client-provided key.
399    ///
400    /// This function is used during the WebSocket handshake to validate the client's request.
401    /// It concatenates the client's key with a specific GUID, calculates the SHA-1 hash of the result,
402    /// and then encodes the hash in base64.
403    ///
404    /// # Arguments
405    ///
406    /// - `AsRef<str>` - The client-provided key (typically from the `Sec-WebSocket-Key` header).
407    ///
408    /// # Returns
409    ///
410    /// - `Option<String>` - An optional string representing the generated WebSocket accept key (typically for the `Sec-WebSocket-Accept` header).
411    ///
412    /// # Panics
413    ///
414    /// This function will panic if the input key cannot be converted to a UTF-8 string.
415    #[inline(always)]
416    pub fn generate_accept_key<K>(key: K) -> String
417    where
418        K: AsRef<str>,
419    {
420        let key_ref: &str = key.as_ref();
421        let mut data: [u8; 60] = [0u8; 60];
422        data[..24].copy_from_slice(&key_ref.as_bytes()[..24.min(key_ref.len())]);
423        data[24..].copy_from_slice(GUID);
424        let hash: [u8; 20] = Self::sha1(data);
425        Self::base64_encode(hash)
426    }
427
428    /// Encodes the input data as a base64 string, returning an `Option<String>`.
429    ///
430    /// # Arguments
431    ///
432    /// - `AsRef<[u8]>` - The data to encode in base64.
433    ///
434    /// # Returns
435    ///
436    /// - `Option<String>` - An optional string with the base64 encoded representation of the input data.
437    pub fn try_base64_encode<D>(data: D) -> Option<String>
438    where
439        D: AsRef<[u8]>,
440    {
441        let data_ref: &[u8] = data.as_ref();
442        let mut encoded_data: Vec<u8> = Vec::with_capacity(data_ref.len().div_ceil(3) * 4);
443        for chunk in data_ref.chunks(3) {
444            let mut buffer: [u8; 3] = [0u8; 3];
445            buffer[..chunk.len()].copy_from_slice(chunk);
446            let indices: [u8; 4] = [
447                buffer[0] >> 2,
448                ((buffer[0] & 0b11) << 4) | (buffer[1] >> 4),
449                ((buffer[1] & 0b1111) << 2) | (buffer[2] >> 6),
450                buffer[2] & 0b111111,
451            ];
452            for &idx in &indices[..chunk.len() + 1] {
453                encoded_data.push(BASE64_CHARSET_TABLE[idx as usize]);
454            }
455            while !encoded_data.len().is_multiple_of(4) {
456                encoded_data.push(EQUAL_BYTES[0]);
457            }
458        }
459        String::from_utf8(encoded_data).ok()
460    }
461
462    /// Encodes the input data as a base64 string.
463    ///
464    /// # Arguments
465    ///
466    /// - `AsRef<[u8]>` - The data to encode in base64.
467    ///
468    /// # Returns
469    ///
470    /// - `String` - A string with the base64 encoded representation of the input data.
471    ///
472    /// # Panics
473    ///
474    /// This function will panic if the input data cannot be converted to a UTF-8 string.
475    #[inline(always)]
476    pub fn base64_encode<D>(data: D) -> String
477    where
478        D: AsRef<[u8]>,
479    {
480        Self::try_base64_encode(data).unwrap()
481    }
482
483    /// Checks if the opcode is a continuation frame.
484    ///
485    /// # Returns
486    ///
487    /// - `bool` - `true` if the opcode is `Continuation`, otherwise `false`.
488    #[inline(always)]
489    pub fn is_continuation_opcode(&self) -> bool {
490        self.opcode.is_continuation()
491    }
492
493    /// Checks if the opcode is a text frame.
494    ///
495    /// # Returns
496    ///
497    /// - `bool` - `true` if the opcode is `Text`, otherwise `false`.
498    #[inline(always)]
499    pub fn is_text_opcode(&self) -> bool {
500        self.opcode.is_text()
501    }
502
503    /// Checks if the opcode is a binary frame.
504    ///
505    /// # Returns
506    ///
507    /// - `bool` - `true` if the opcode is `Binary`, otherwise `false`.
508    #[inline(always)]
509    pub fn is_binary_opcode(&self) -> bool {
510        self.opcode.is_binary()
511    }
512
513    /// Checks if the opcode is a close frame.
514    ///
515    /// # Returns
516    ///
517    /// - `bool` - `true` if the opcode is `Close`, otherwise `false`.
518    #[inline(always)]
519    pub fn is_close_opcode(&self) -> bool {
520        self.opcode.is_close()
521    }
522
523    /// Checks if the opcode is a ping frame.
524    ///
525    /// # Returns
526    ///
527    /// - `bool` - `true` if the opcode is `Ping`, otherwise `false`.
528    #[inline(always)]
529    pub fn is_ping_opcode(&self) -> bool {
530        self.opcode.is_ping()
531    }
532
533    /// Checks if the opcode is a pong frame.
534    ///
535    /// # Returns
536    ///
537    /// - `bool` - `true` if the opcode is `Pong`, otherwise `false`.
538    #[inline(always)]
539    pub fn is_pong_opcode(&self) -> bool {
540        self.opcode.is_pong()
541    }
542
543    /// Checks if the opcode is a reserved frame.
544    ///
545    /// # Returns
546    ///
547    /// - `bool` - `true` if the opcode is `Reserved(_)`, otherwise `false`.
548    #[inline(always)]
549    pub fn is_reserved_opcode(&self) -> bool {
550        self.opcode.is_reserved()
551    }
552
553    /// Handles a decoded WebSocket Text or Binary frame and accumulates payload data.
554    ///
555    /// # Arguments
556    ///
557    /// - `&mut Vec<u8>`: The accumulated frame data.
558    ///
559    /// # Returns
560    ///
561    /// - `Result<Option<RequestBody>, RequestError>`: Some(request) if frame is complete, None to continue, or error.
562    #[inline(always)]
563    pub(crate) fn build_full_frame(
564        &self,
565        full_frame: &mut Vec<u8>,
566    ) -> Result<Option<RequestBody>, RequestError> {
567        let payload_data: &[u8] = self.get_payload_data();
568        full_frame.extend_from_slice(payload_data);
569        if *self.get_fin() {
570            return Ok(Some(full_frame.clone()));
571        }
572        Ok(None)
573    }
574}