// pipa/http/ws/frame.rs

/// Frame opcode carried in the low nibble of a WebSocket frame's first byte.
///
/// Each discriminant is the on-the-wire value from RFC 6455 §5.2, so `op as u8`
/// yields exactly the bits written into the frame header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OpCode {
    /// Continuation fragment of a fragmented message.
    Continuation = 0x00,
    /// Text data frame.
    Text = 0x01,
    /// Binary data frame.
    Binary = 0x02,
    /// Connection close (control frame).
    Close = 0x08,
    /// Ping (control frame).
    Ping = 0x09,
    /// Pong (control frame).
    Pong = 0x0A,
}

11impl OpCode {
12    pub fn from_u4(v: u8) -> Option<Self> {
13        match v {
14            0x0 => Some(OpCode::Continuation),
15            0x1 => Some(OpCode::Text),
16            0x2 => Some(OpCode::Binary),
17            0x8 => Some(OpCode::Close),
18            0x9 => Some(OpCode::Ping),
19            0xA => Some(OpCode::Pong),
20            _ => None,
21        }
22    }
23
24    pub fn is_control(&self) -> bool {
25        matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong)
26    }
27}
28
29#[derive(Debug, Clone)]
30pub struct WsFrame {
31    pub fin: bool,
32    pub opcode: OpCode,
33    pub mask: Option<[u8; 4]>,
34    pub payload: Vec<u8>,
35}
36
37impl WsFrame {
38    pub fn new_text(payload: Vec<u8>) -> Self {
39        WsFrame {
40            fin: true,
41            opcode: OpCode::Text,
42            mask: Some(generate_mask()),
43            payload,
44        }
45    }
46
47    pub fn new_binary(payload: Vec<u8>) -> Self {
48        WsFrame {
49            fin: true,
50            opcode: OpCode::Binary,
51            mask: Some(generate_mask()),
52            payload,
53        }
54    }
55
56    pub fn new_close(code: u16, reason: &str) -> Self {
57        let mut payload = Vec::with_capacity(2 + reason.len());
58        payload.extend_from_slice(&code.to_be_bytes());
59        payload.extend_from_slice(reason.as_bytes());
60        WsFrame {
61            fin: true,
62            opcode: OpCode::Close,
63            mask: Some(generate_mask()),
64            payload,
65        }
66    }
67
68    pub fn new_pong(payload: Vec<u8>) -> Self {
69        WsFrame {
70            fin: true,
71            opcode: OpCode::Pong,
72            mask: Some(generate_mask()),
73            payload,
74        }
75    }
76
77    pub fn encode(&self) -> Vec<u8> {
78        let mut buf = Vec::new();
79        let b0 = if self.fin { 0x80u8 } else { 0u8 } | (self.opcode as u8);
80        buf.push(b0);
81
82        let masked = self.mask.is_some() as u8;
83        let payload_len = self.payload.len();
84
85        if payload_len < 126 {
86            buf.push((masked << 7) | payload_len as u8);
87        } else if payload_len <= 0xFFFF {
88            buf.push((masked << 7) | 126);
89            buf.extend_from_slice(&(payload_len as u16).to_be_bytes());
90        } else {
91            buf.push((masked << 7) | 127);
92            buf.extend_from_slice(&(payload_len as u64).to_be_bytes());
93        }
94
95        let mask_key = self.mask.unwrap_or([0u8; 4]);
96        if self.mask.is_some() {
97            buf.extend_from_slice(&mask_key);
98        }
99
100        let mut masked_payload = self.payload.clone();
101        if self.mask.is_some() {
102            for (i, byte) in masked_payload.iter_mut().enumerate() {
103                *byte ^= mask_key[i % 4];
104            }
105        }
106        buf.extend_from_slice(&masked_payload);
107        buf
108    }
109
110    pub fn parse_all(data: &[u8]) -> Result<Vec<WsFrame>, String> {
111        let mut frames = Vec::new();
112        let mut pos = 0;
113        while pos < data.len() {
114            let remaining = data.len() - pos;
115            if remaining < 2 {
116                break;
117            }
118            let b0 = data[pos];
119            let b1 = data[pos + 1];
120            let fin = (b0 & 0x80) != 0;
121            let opcode_val = b0 & 0x0F;
122            let opcode = OpCode::from_u4(opcode_val)
123                .ok_or_else(|| format!("unknown opcode: {opcode_val:#x}"))?;
124            let masked = (b1 & 0x80) != 0;
125            let mut payload_len = (b1 & 0x7F) as u64;
126
127            let mut header_len = 2;
128            if payload_len == 126 {
129                header_len += 2;
130            } else if payload_len == 127 {
131                header_len += 8;
132            }
133            let mask_len = if masked { 4 } else { 0 };
134            let total_header = header_len + mask_len;
135
136            if remaining < total_header {
137                break;
138            }
139
140            if payload_len == 126 {
141                payload_len = u64::from_be_bytes([0, 0, 0, 0, 0, 0, data[pos + 2], data[pos + 3]]);
142            } else if payload_len == 127 {
143                let mut arr = [0u8; 8];
144                arr.copy_from_slice(&data[pos + 2..pos + 10]);
145                payload_len = u64::from_be_bytes(arr);
146            }
147
148            let total_frame = total_header + payload_len as usize;
149            if remaining < total_frame {
150                break;
151            }
152
153            let mut payload = data[pos + total_header..pos + total_frame].to_vec();
154
155            let mask_key = if masked {
156                let mk = [
157                    data[pos + header_len],
158                    data[pos + header_len + 1],
159                    data[pos + header_len + 2],
160                    data[pos + header_len + 3],
161                ];
162                for (i, byte) in payload.iter_mut().enumerate() {
163                    *byte ^= mk[i % 4];
164                }
165                Some(mk)
166            } else {
167                None
168            };
169
170            frames.push(WsFrame {
171                fin,
172                opcode,
173                mask: mask_key,
174                payload,
175            });
176
177            pos += total_frame;
178        }
179        Ok(frames)
180    }
181}
182
/// Returns a fresh 4-byte masking key for client-to-server frames.
///
/// RFC 6455 §5.3 requires masking keys to be unpredictable. The wall-clock time
/// alone is guessable, and it can repeat for frames built within the same clock tick.
/// The timestamp is therefore hashed with std's keyed hasher obtained from a fresh
/// `RandomState`, which std documents as initialized with random keys.
///
/// NOTE(review): this is not a CSPRNG. It is adequate for frame masking, but switch to
/// a proper random source if the crate ever gains one.
fn generate_mask() -> [u8; 4] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    // Keep the low 4 bytes of the 64-bit hash as the key.
    let [b0, b1, b2, b3, ..] = hasher.finish().to_le_bytes();
    [b0, b1, b2, b3]
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_decode_text() {
        let frame = WsFrame::new_text(b"Hello".to_vec());
        let encoded = frame.encode();
        let parsed = WsFrame::parse_all(&encoded).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].opcode, OpCode::Text);
        assert!(parsed[0].fin);
        assert_eq!(parsed[0].payload, b"Hello");
    }

    #[test]
    fn test_close_frame() {
        let frame = WsFrame::new_close(1000, "Normal");
        let encoded = frame.encode();
        let parsed = WsFrame::parse_all(&encoded).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].opcode, OpCode::Close);
        assert_eq!(&parsed[0].payload[..2], &[0x03, 0xE8]);
    }

    /// Covers the 7-bit, 16-bit and 64-bit payload-length encodings at their boundaries.
    #[test]
    fn test_extended_lengths_round_trip() {
        for len in [125usize, 126, 65_535, 65_536] {
            let payload: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();
            let encoded = WsFrame::new_binary(payload.clone()).encode();
            let parsed = WsFrame::parse_all(&encoded).unwrap();
            assert_eq!(parsed.len(), 1, "payload length {len}");
            assert_eq!(parsed[0].opcode, OpCode::Binary);
            assert_eq!(parsed[0].payload, payload, "payload length {len}");
        }
    }

    #[test]
    fn test_unmasked_frame() {
        let frame = WsFrame {
            fin: true,
            opcode: OpCode::Text,
            mask: None,
            payload: b"Hi".to_vec(),
        };
        let encoded = frame.encode();
        assert_eq!(encoded, [0x81, 0x02, b'H', b'i']);
        let parsed = WsFrame::parse_all(&encoded).unwrap();
        assert_eq!(parsed[0].mask, None);
        assert_eq!(parsed[0].payload, b"Hi");
    }

    /// Every strict prefix of a frame must yield no frames rather than garbage.
    #[test]
    fn test_incomplete_frame_is_not_returned() {
        let encoded = WsFrame::new_text(b"Hello".to_vec()).encode();
        for cut in 0..encoded.len() {
            assert!(WsFrame::parse_all(&encoded[..cut]).unwrap().is_empty());
        }
    }

    #[test]
    fn test_multiple_frames() {
        let mut data = WsFrame::new_text(b"a".to_vec()).encode();
        data.extend(WsFrame::new_pong(b"p".to_vec()).encode());
        let parsed = WsFrame::parse_all(&data).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].payload, b"a");
        assert_eq!(parsed[1].opcode, OpCode::Pong);
        assert_eq!(parsed[1].payload, b"p");
    }

    #[test]
    fn test_unknown_opcode_is_error() {
        assert!(WsFrame::parse_all(&[0x83, 0x00]).is_err());
    }
}