http_type/websocket_frame/
impl.rs

1use crate::*;
2
3/// Implements the `Default` trait for `WebSocketFrame`.
4///
5/// Provides a default `WebSocketFrame` with `fin: false`, `opcode: WebSocketOpcode::Text`,
6/// `mask: false`, and an empty `payload_data`.
7impl Default for WebSocketFrame {
8    /// Returns the default `WebSocketFrame`.
9    ///
10    /// # Returns
11    ///
12    /// A default `WebSocketFrame` instance.
13    #[inline(always)]
14    fn default() -> Self {
15        Self {
16            fin: false,
17            opcode: WebSocketOpcode::Text,
18            mask: false,
19            payload_data: Vec::new(),
20        }
21    }
22}
23
24impl WebSocketOpcode {
25    /// Creates a `WebSocketOpcode` from a raw u8 value.
26    ///
27    /// # Arguments
28    ///
29    /// - `opcode`: The raw opcode value.
30    ///
31    /// # Returns
32    ///
33    /// A `WebSocketOpcode` enum variant corresponding to the raw value.
34    #[inline(always)]
35    pub fn from_u8(opcode: u8) -> Self {
36        match opcode {
37            0x0 => Self::Continuation,
38            0x1 => Self::Text,
39            0x2 => Self::Binary,
40            0x8 => Self::Close,
41            0x9 => Self::Ping,
42            0xA => Self::Pong,
43            _ => Self::Reserved(opcode),
44        }
45    }
46
47    /// Converts the `WebSocketOpcode` to its raw u8 value.
48    ///
49    /// # Returns
50    ///
51    /// The raw u8 value of the opcode.
52    #[inline(always)]
53    pub fn to_u8(&self) -> u8 {
54        match self {
55            Self::Continuation => 0x0,
56            Self::Text => 0x1,
57            Self::Binary => 0x2,
58            Self::Close => 0x8,
59            Self::Ping => 0x9,
60            Self::Pong => 0xA,
61            Self::Reserved(code) => *code,
62        }
63    }
64
65    /// Checks if the opcode is a control frame.
66    ///
67    /// # Returns
68    ///
69    /// `true` if the opcode represents a control frame (Close, Ping, Pong), otherwise `false`.
70    #[inline(always)]
71    pub fn is_control(&self) -> bool {
72        matches!(self, Self::Close | Self::Ping | Self::Pong)
73    }
74
75    /// Checks if the opcode is a data frame.
76    ///
77    /// # Returns
78    ///
79    /// `true` if the opcode represents a data frame (Text, Binary, Continuation), otherwise `false`.
80    #[inline(always)]
81    pub fn is_data(&self) -> bool {
82        matches!(self, Self::Text | Self::Binary | Self::Continuation)
83    }
84
85    /// Checks if the opcode is a continuation frame.
86    ///
87    /// # Returns
88    ///
89    /// `true` if the opcode is `Continuation`, otherwise `false`.
90    #[inline(always)]
91    pub fn is_continuation(&self) -> bool {
92        matches!(self, Self::Continuation)
93    }
94
95    /// Checks if the opcode is a text frame.
96    ///
97    /// # Returns
98    ///
99    /// `true` if the opcode is `Text`, otherwise `false`.
100    #[inline(always)]
101    pub fn is_text(&self) -> bool {
102        matches!(self, Self::Text)
103    }
104
105    /// Checks if the opcode is a binary frame.
106    ///
107    /// # Returns
108    ///
109    /// `true` if the opcode is `Binary`, otherwise `false`.
110    #[inline(always)]
111    pub fn is_binary(&self) -> bool {
112        matches!(self, Self::Binary)
113    }
114
115    /// Checks if the opcode is a close frame.
116    ///
117    /// # Returns
118    ///
119    /// `true` if the opcode is `Close`, otherwise `false`.
120    #[inline(always)]
121    pub fn is_close(&self) -> bool {
122        matches!(self, Self::Close)
123    }
124
125    /// Checks if the opcode is a ping frame.
126    ///
127    /// # Returns
128    ///
129    /// `true` if the opcode is `Ping`, otherwise `false`.
130    #[inline(always)]
131    pub fn is_ping(&self) -> bool {
132        matches!(self, Self::Ping)
133    }
134
135    /// Checks if the opcode is a pong frame.
136    ///
137    /// # Returns
138    ///
139    /// `true` if the opcode is `Pong`, otherwise `false`.
140    #[inline(always)]
141    pub fn is_pong(&self) -> bool {
142        matches!(self, Self::Pong)
143    }
144
145    /// Checks if the opcode is a reserved frame.
146    ///
147    /// # Returns
148    ///
149    /// `true` if the opcode is `Reserved(_)`, otherwise `false`.
150    #[inline(always)]
151    pub fn is_reserved(&self) -> bool {
152        matches!(self, Self::Reserved(_))
153    }
154}
155
156impl WebSocketFrame {
157    /// Decodes a WebSocket frame from the provided data slice.
158    ///
159    /// This function parses the raw bytes from a WebSocket stream according to the WebSocket protocol
160    /// specification to reconstruct a `WebSocketFrame`. It handles FIN bit, opcode, mask bit,
161    /// payload length (including extended lengths), mask key, and the payload data itself.
162    ///
163    /// # Arguments
164    ///
165    /// - `AsRef<[u8]>` - The raw data to decode into a WebSocket frame.
166    ///
167    /// # Returns
168    ///
169    /// - `Some((WebSocketFrame, usize))`: If the frame is successfully decoded, returns the decoded frame
170    ///   and the number of bytes consumed from the input slice.
171    /// - `None`: If the frame is incomplete or malformed.
172    pub fn decode_ws_frame<D>(data: D) -> Option<(WebSocketFrame, usize)>
173    where
174        D: AsRef<[u8]>,
175    {
176        let data_ref: &[u8] = data.as_ref();
177        if data_ref.len() < 2 {
178            return None;
179        }
180        let mut index: usize = 0;
181        let fin: bool = (data_ref[index] & 0b1000_0000) != 0;
182        let opcode: WebSocketOpcode = WebSocketOpcode::from_u8(data_ref[index] & 0b0000_1111);
183        index += 1;
184        let mask: bool = (data_ref[index] & 0b1000_0000) != 0;
185        let mut payload_len: usize = (data_ref[index] & 0b0111_1111) as usize;
186        index += 1;
187        if payload_len == 126 {
188            if data_ref.len() < index + 2 {
189                return None;
190            }
191            payload_len = u16::from_be_bytes(data_ref[index..index + 2].try_into().ok()?) as usize;
192            index += 2;
193        } else if payload_len == 127 {
194            if data_ref.len() < index + 8 {
195                return None;
196            }
197            payload_len = u64::from_be_bytes(data_ref[index..index + 8].try_into().ok()?) as usize;
198            index += 8;
199        }
200        let mask_key: Option<[u8; 4]> = if mask {
201            if data_ref.len() < index + 4 {
202                return None;
203            }
204            let key: [u8; 4] = data_ref[index..index + 4].try_into().ok()?;
205            index += 4;
206            Some(key)
207        } else {
208            None
209        };
210        if data_ref.len() < index + payload_len {
211            return None;
212        }
213        let mut payload: Vec<u8> = data_ref[index..index + payload_len].to_vec();
214        if let Some(mask_key) = mask_key {
215            for (i, byte) in payload.iter_mut().enumerate() {
216                *byte ^= mask_key[i % 4];
217            }
218        }
219        index += payload_len;
220        let frame: WebSocketFrame = WebSocketFrame {
221            fin,
222            opcode,
223            mask,
224            payload_data: payload,
225        };
226        Some((frame, index))
227    }
228
229    /// Creates a list of response frames from the provided body.
230    ///
231    /// This method segments the response body into WebSocket frames, respecting the maximum frame size
232    /// and handling UTF-8 character boundaries for text frames. It determines the appropriate opcode
233    /// (Text or Binary) based on the body's content.
234    ///
235    /// # Arguments
236    ///
237    /// - `AsRef<[u8]>` - A reference to a response body (payload) as a byte slice.
238    ///
239    /// # Returns
240    ///
241    /// - A vector of `ResponseBody` (byte vectors), where each element represents a framed WebSocket message.
242    pub fn create_frame_list<D>(data: D) -> Vec<ResponseBody>
243    where
244        D: AsRef<[u8]>,
245    {
246        let data_ref: &[u8] = data.as_ref();
247        let total_len: usize = data_ref.len();
248        let mut offset: usize = 0;
249        let mut frames_list: Vec<ResponseBody> =
250            Vec::with_capacity((total_len / MAX_FRAME_SIZE) + 1);
251        let mut is_first_frame: bool = true;
252        let is_valid_utf8: bool = std::str::from_utf8(data_ref).is_ok();
253        let base_opcode: WebSocketOpcode = if is_valid_utf8 {
254            WebSocketOpcode::Text
255        } else {
256            WebSocketOpcode::Binary
257        };
258        while offset < total_len {
259            let remaining: usize = total_len - offset;
260            let mut frame_size: usize = remaining.min(MAX_FRAME_SIZE);
261            if is_valid_utf8 && frame_size < remaining {
262                while frame_size > 0 && (data_ref[offset + frame_size] & 0xC0) == 0x80 {
263                    frame_size -= 1;
264                }
265                if frame_size == 0 {
266                    frame_size = remaining.min(MAX_FRAME_SIZE);
267                }
268            }
269            let mut frame: ResponseBody = Vec::with_capacity(frame_size + 10);
270            let opcode: WebSocketOpcode = if is_first_frame {
271                base_opcode
272            } else {
273                WebSocketOpcode::Continuation
274            };
275            let fin: u8 = if remaining > frame_size { 0x00 } else { 0x80 };
276            let opcode_byte: u8 = opcode.to_u8() & 0x0F;
277            frame.push(fin | opcode_byte);
278            if frame_size < 126 {
279                frame.push(frame_size as u8);
280            } else if frame_size <= MAX_FRAME_SIZE {
281                frame.push(126);
282                frame.extend_from_slice(&(frame_size as u16).to_be_bytes());
283            } else {
284                frame.push(127);
285                frame.extend_from_slice(&(frame_size as u16).to_be_bytes());
286            }
287            let end: usize = offset + frame_size;
288            frame.extend_from_slice(&data_ref[offset..end]);
289            frames_list.push(frame);
290            offset = end;
291            is_first_frame = false;
292        }
293        frames_list
294    }
295
296    /// Creates a ping frame for WebSocket heartbeat/ping mechanism.
297    ///
298    /// # Returns
299    ///
300    /// - A byte vector representing a WebSocket ping frame.
301    #[inline(always)]
302    pub fn create_ping_frame() -> Vec<u8> {
303        let mut frame: Vec<u8> = Vec::with_capacity(2);
304        frame.push(0x89);
305        frame.push(0x00);
306        frame
307    }
308
309    /// Calculates the SHA-1 hash of the input data.
310    ///
311    /// This function implements the SHA-1 cryptographic hash algorithm according to RFC 3174.
312    /// It processes the input data in 512-bit (64-byte) blocks and produces a 160-bit (20-byte) hash.
313    ///
314    /// # Arguments
315    ///
316    /// - `AsRef<[u8]>` - The input data to be hashed.
317    ///
318    /// # Returns
319    ///
320    /// - A 20-byte array representing the SHA-1 hash of the input data.
321    pub fn sha1<D>(data: D) -> [u8; 20]
322    where
323        D: AsRef<[u8]>,
324    {
325        let data_ref: &[u8] = data.as_ref();
326        let mut hash_state: [u32; 5] = HASH_STATE;
327        let mut padded_data: Vec<u8> = Vec::from(data_ref);
328        let original_length_bits: u64 = (padded_data.len() * 8) as u64;
329        padded_data.push(0x80);
330        while !(padded_data.len() + 8).is_multiple_of(64) {
331            padded_data.push(0);
332        }
333        padded_data.extend_from_slice(&original_length_bits.to_be_bytes());
334        for block in padded_data.chunks_exact(64) {
335            let mut message_schedule: [u32; 80] = [0u32; 80];
336            for (i, block_chunk) in block.chunks_exact(4).enumerate().take(16) {
337                message_schedule[i] = u32::from_be_bytes([
338                    block_chunk[0],
339                    block_chunk[1],
340                    block_chunk[2],
341                    block_chunk[3],
342                ]);
343            }
344            for i in 16..80 {
345                message_schedule[i] = (message_schedule[i - 3]
346                    ^ message_schedule[i - 8]
347                    ^ message_schedule[i - 14]
348                    ^ message_schedule[i - 16])
349                    .rotate_left(1);
350            }
351            let [mut a, mut b, mut c, mut d, mut e] = hash_state;
352            for (i, &word) in message_schedule.iter().enumerate() {
353                let (f, k) = match i {
354                    0..=19 => ((b & c) | (!b & d), 0x5A827999),
355                    20..=39 => (b ^ c ^ d, 0x6ED9EBA1),
356                    40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC),
357                    _ => (b ^ c ^ d, 0xCA62C1D6),
358                };
359                let temp: u32 = a
360                    .rotate_left(5)
361                    .wrapping_add(f)
362                    .wrapping_add(e)
363                    .wrapping_add(k)
364                    .wrapping_add(word);
365                e = d;
366                d = c;
367                c = b.rotate_left(30);
368                b = a;
369                a = temp;
370            }
371            hash_state[0] = hash_state[0].wrapping_add(a);
372            hash_state[1] = hash_state[1].wrapping_add(b);
373            hash_state[2] = hash_state[2].wrapping_add(c);
374            hash_state[3] = hash_state[3].wrapping_add(d);
375            hash_state[4] = hash_state[4].wrapping_add(e);
376        }
377        let mut result: [u8; 20] = [0u8; 20];
378        for (i, &val) in hash_state.iter().enumerate() {
379            result[i * 4..(i + 1) * 4].copy_from_slice(&val.to_be_bytes());
380        }
381        result
382    }
383
384    /// Generates a WebSocket accept key from the client-provided key, returning an `Option<String>`.
385    ///
386    /// # Arguments
387    ///
388    /// - `AsRef<str>` - The client-provided key (typically from the `Sec-WebSocket-Key` header).
389    ///
390    /// # Returns
391    ///
392    /// - `Option<String>` - An optional string representing the generated WebSocket accept key (typically for the `Sec-WebSocket-Accept` header).
393    pub fn try_generate_accept_key<K>(key: K) -> Option<String>
394    where
395        K: AsRef<str>,
396    {
397        let key_ref: &str = key.as_ref();
398        let mut data: [u8; 60] = [0u8; 60];
399        data[..24].copy_from_slice(&key_ref.as_bytes()[..24.min(key_ref.len())]);
400        data[24..].copy_from_slice(GUID);
401        let hash: [u8; 20] = Self::sha1(data);
402        Self::try_base64_encode(hash)
403    }
404
405    /// Generates a WebSocket accept key from the client-provided key.
406    ///
407    /// This function is used during the WebSocket handshake to validate the client's request.
408    /// It concatenates the client's key with a specific GUID, calculates the SHA-1 hash of the result,
409    /// and then encodes the hash in base64.
410    ///
411    /// # Arguments
412    ///
413    /// - `AsRef<str>` - The client-provided key (typically from the `Sec-WebSocket-Key` header).
414    ///
415    /// # Returns
416    ///
417    /// - `Option<String>` - An optional string representing the generated WebSocket accept key (typically for the `Sec-WebSocket-Accept` header).
418    ///
419    /// # Panics
420    ///
421    /// This function will panic if the input key cannot be converted to a UTF-8 string.
422    pub fn generate_accept_key<K>(key: K) -> String
423    where
424        K: AsRef<str>,
425    {
426        let key_ref: &str = key.as_ref();
427        let mut data: [u8; 60] = [0u8; 60];
428        data[..24].copy_from_slice(&key_ref.as_bytes()[..24.min(key_ref.len())]);
429        data[24..].copy_from_slice(GUID);
430        let hash: [u8; 20] = Self::sha1(data);
431        Self::base64_encode(hash)
432    }
433
434    /// Encodes the input data as a base64 string, returning an `Option<String>`.
435    ///
436    /// # Arguments
437    ///
438    /// - `AsRef<[u8]>` - The data to encode in base64.
439    ///
440    /// # Returns
441    ///
442    /// - `Option<String>` - An optional string with the base64 encoded representation of the input data.
443    pub fn try_base64_encode<D>(data: D) -> Option<String>
444    where
445        D: AsRef<[u8]>,
446    {
447        let data_ref: &[u8] = data.as_ref();
448        let mut encoded_data: Vec<u8> = Vec::with_capacity(data_ref.len().div_ceil(3) * 4);
449        for chunk in data_ref.chunks(3) {
450            let mut buffer: [u8; 3] = [0u8; 3];
451            buffer[..chunk.len()].copy_from_slice(chunk);
452            let indices: [u8; 4] = [
453                buffer[0] >> 2,
454                ((buffer[0] & 0b11) << 4) | (buffer[1] >> 4),
455                ((buffer[1] & 0b1111) << 2) | (buffer[2] >> 6),
456                buffer[2] & 0b111111,
457            ];
458            for &idx in &indices[..chunk.len() + 1] {
459                encoded_data.push(BASE64_CHARSET_TABLE[idx as usize]);
460            }
461            while !encoded_data.len().is_multiple_of(4) {
462                encoded_data.push(EQUAL_BYTES[0]);
463            }
464        }
465        String::from_utf8(encoded_data).ok()
466    }
467
468    /// Encodes the input data as a base64 string.
469    ///
470    /// # Arguments
471    ///
472    /// - `AsRef<[u8]>` - The data to encode in base64.
473    ///
474    /// # Returns
475    ///
476    /// - A string with the base64 encoded representation of the input data.
477    ///
478    /// # Panics
479    ///
480    /// This function will panic if the input data cannot be converted to a UTF-8 string.
481    pub fn base64_encode<D>(data: D) -> String
482    where
483        D: AsRef<[u8]>,
484    {
485        Self::try_base64_encode(data).unwrap()
486    }
487
488    /// Checks if the opcode is a continuation frame.
489    ///
490    /// # Returns
491    ///
492    /// `true` if the opcode is `Continuation`, otherwise `false`.
493    #[inline(always)]
494    pub fn is_continuation_opcode(&self) -> bool {
495        self.opcode.is_continuation()
496    }
497
498    /// Checks if the opcode is a text frame.
499    ///
500    /// # Returns
501    ///
502    /// `true` if the opcode is `Text`, otherwise `false`.
503    #[inline(always)]
504    pub fn is_text_opcode(&self) -> bool {
505        self.opcode.is_text()
506    }
507
508    /// Checks if the opcode is a binary frame.
509    ///
510    /// # Returns
511    ///
512    /// `true` if the opcode is `Binary`, otherwise `false`.
513    #[inline(always)]
514    pub fn is_binary_opcode(&self) -> bool {
515        self.opcode.is_binary()
516    }
517
518    /// Checks if the opcode is a close frame.
519    ///
520    /// # Returns
521    ///
522    /// `true` if the opcode is `Close`, otherwise `false`.
523    #[inline(always)]
524    pub fn is_close_opcode(&self) -> bool {
525        self.opcode.is_close()
526    }
527
528    /// Checks if the opcode is a ping frame.
529    ///
530    /// # Returns
531    ///
532    /// `true` if the opcode is `Ping`, otherwise `false`.
533    #[inline(always)]
534    pub fn is_ping_opcode(&self) -> bool {
535        self.opcode.is_ping()
536    }
537
538    /// Checks if the opcode is a pong frame.
539    ///
540    /// # Returns
541    ///
542    /// `true` if the opcode is `Pong`, otherwise `false`.
543    #[inline(always)]
544    pub fn is_pong_opcode(&self) -> bool {
545        self.opcode.is_pong()
546    }
547
548    /// Checks if the opcode is a reserved frame.
549    ///
550    /// # Returns
551    ///
552    /// `true` if the opcode is `Reserved(_)`, otherwise `false`.
553    #[inline(always)]
554    pub fn is_reserved_opcode(&self) -> bool {
555        self.opcode.is_reserved()
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// The ping frame must be exactly the two header bytes `0x89 0x00`.
    #[test]
    fn test_create_ping_frame() {
        let wire: Vec<u8> = WebSocketFrame::create_ping_frame();
        assert_eq!(wire.as_slice(), &[0x89u8, 0x00u8][..]);
    }

    /// A generated ping frame must decode back into an empty ping frame that
    /// consumes both header bytes.
    #[test]
    fn test_ping_frame_decode() {
        let wire: Vec<u8> = WebSocketFrame::create_ping_frame();
        let (frame, consumed): (WebSocketFrame, usize) =
            WebSocketFrame::decode_ws_frame(&wire).unwrap();
        assert_eq!(consumed, 2);
        assert!(frame.is_ping_opcode());
        assert!(frame.get_payload_data().is_empty());
    }
}