/// Frame opcode carried in the low four bits of a frame's first byte.
///
/// Each variant's discriminant is its wire value, so `opcode as u8` yields
/// the nibble to place in the header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpCode {
    /// Wire value `0x0`.
    Continuation = 0x0,
    /// Wire value `0x1`.
    Text = 0x1,
    /// Wire value `0x2`.
    Binary = 0x2,
    /// Wire value `0x8`.
    Close = 0x8,
    /// Wire value `0x9`.
    Ping = 0x9,
    /// Wire value `0xA`.
    Pong = 0xA,
}

impl OpCode {
    /// Maps a raw opcode value to its variant.
    ///
    /// Returns `None` for every value that is not one of the six
    /// discriminants declared on [`OpCode`] (including anything above `0xF`).
    pub fn from_u4(v: u8) -> Option<Self> {
        // Every variant, so the lookup is a scan over discriminants.
        const KNOWN: [OpCode; 6] = [
            OpCode::Continuation,
            OpCode::Text,
            OpCode::Binary,
            OpCode::Close,
            OpCode::Ping,
            OpCode::Pong,
        ];
        KNOWN.iter().copied().find(|op| *op as u8 == v)
    }

    /// Returns `true` for `Close`, `Ping` and `Pong`, `false` otherwise.
    pub fn is_control(&self) -> bool {
        // Among the declared variants, exactly Close (0x8), Ping (0x9) and
        // Pong (0xA) have bit 3 set.
        (*self as u8) & 0x8 != 0
    }
}
28
29#[derive(Debug, Clone)]
30pub struct WsFrame {
31 pub fin: bool,
32 pub opcode: OpCode,
33 pub mask: Option<[u8; 4]>,
34 pub payload: Vec<u8>,
35}
36
37impl WsFrame {
38 pub fn new_text(payload: Vec<u8>) -> Self {
39 WsFrame {
40 fin: true,
41 opcode: OpCode::Text,
42 mask: Some(generate_mask()),
43 payload,
44 }
45 }
46
47 pub fn new_binary(payload: Vec<u8>) -> Self {
48 WsFrame {
49 fin: true,
50 opcode: OpCode::Binary,
51 mask: Some(generate_mask()),
52 payload,
53 }
54 }
55
56 pub fn new_close(code: u16, reason: &str) -> Self {
57 let mut payload = Vec::with_capacity(2 + reason.len());
58 payload.extend_from_slice(&code.to_be_bytes());
59 payload.extend_from_slice(reason.as_bytes());
60 WsFrame {
61 fin: true,
62 opcode: OpCode::Close,
63 mask: Some(generate_mask()),
64 payload,
65 }
66 }
67
68 pub fn new_pong(payload: Vec<u8>) -> Self {
69 WsFrame {
70 fin: true,
71 opcode: OpCode::Pong,
72 mask: Some(generate_mask()),
73 payload,
74 }
75 }
76
77 pub fn encode(&self) -> Vec<u8> {
78 let mut buf = Vec::new();
79 let b0 = if self.fin { 0x80u8 } else { 0u8 } | (self.opcode as u8);
80 buf.push(b0);
81
82 let masked = self.mask.is_some() as u8;
83 let payload_len = self.payload.len();
84
85 if payload_len < 126 {
86 buf.push((masked << 7) | payload_len as u8);
87 } else if payload_len <= 0xFFFF {
88 buf.push((masked << 7) | 126);
89 buf.extend_from_slice(&(payload_len as u16).to_be_bytes());
90 } else {
91 buf.push((masked << 7) | 127);
92 buf.extend_from_slice(&(payload_len as u64).to_be_bytes());
93 }
94
95 let mask_key = self.mask.unwrap_or([0u8; 4]);
96 if self.mask.is_some() {
97 buf.extend_from_slice(&mask_key);
98 }
99
100 let mut masked_payload = self.payload.clone();
101 if self.mask.is_some() {
102 for (i, byte) in masked_payload.iter_mut().enumerate() {
103 *byte ^= mask_key[i % 4];
104 }
105 }
106 buf.extend_from_slice(&masked_payload);
107 buf
108 }
109
110 pub fn parse_all(data: &[u8]) -> Result<Vec<WsFrame>, String> {
111 let mut frames = Vec::new();
112 let mut pos = 0;
113 while pos < data.len() {
114 let remaining = data.len() - pos;
115 if remaining < 2 {
116 break;
117 }
118 let b0 = data[pos];
119 let b1 = data[pos + 1];
120 let fin = (b0 & 0x80) != 0;
121 let opcode_val = b0 & 0x0F;
122 let opcode = OpCode::from_u4(opcode_val)
123 .ok_or_else(|| format!("unknown opcode: {opcode_val:#x}"))?;
124 let masked = (b1 & 0x80) != 0;
125 let mut payload_len = (b1 & 0x7F) as u64;
126
127 let mut header_len = 2;
128 if payload_len == 126 {
129 header_len += 2;
130 } else if payload_len == 127 {
131 header_len += 8;
132 }
133 let mask_len = if masked { 4 } else { 0 };
134 let total_header = header_len + mask_len;
135
136 if remaining < total_header {
137 break;
138 }
139
140 if payload_len == 126 {
141 payload_len = u64::from_be_bytes([0, 0, 0, 0, 0, 0, data[pos + 2], data[pos + 3]]);
142 } else if payload_len == 127 {
143 let mut arr = [0u8; 8];
144 arr.copy_from_slice(&data[pos + 2..pos + 10]);
145 payload_len = u64::from_be_bytes(arr);
146 }
147
148 let total_frame = total_header + payload_len as usize;
149 if remaining < total_frame {
150 break;
151 }
152
153 let mut payload = data[pos + total_header..pos + total_frame].to_vec();
154
155 let mask_key = if masked {
156 let mk = [
157 data[pos + header_len],
158 data[pos + header_len + 1],
159 data[pos + header_len + 2],
160 data[pos + header_len + 3],
161 ];
162 for (i, byte) in payload.iter_mut().enumerate() {
163 *byte ^= mk[i % 4];
164 }
165 Some(mk)
166 } else {
167 None
168 };
169
170 frames.push(WsFrame {
171 fin,
172 opcode,
173 mask: mask_key,
174 payload,
175 });
176
177 pos += total_frame;
178 }
179 Ok(frames)
180 }
181}
182
/// Produces a fresh 4-byte masking key for an outgoing frame.
///
/// RFC 6455 §5.3 requires masking keys to be unpredictable. The key is taken
/// from a `RandomState` hasher, which the standard library initialises with
/// random keys, instead of from the wall clock (whose low bits are guessable
/// and can repeat on coarse clocks).
fn generate_mask() -> [u8; 4] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // NOTE(review): std does not document `RandomState` as a CSPRNG. If the
    // project gains a dedicated random source, prefer it here.
    let seed = RandomState::new().build_hasher().finish();
    let [b0, b1, b2, b3, ..] = seed.to_le_bytes();
    [b0, b1, b2, b3]
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// A text frame survives an encode → parse round trip unchanged.
    #[test]
    fn test_encode_decode_text() {
        let wire = WsFrame::new_text(b"Hello".to_vec()).encode();
        let frames = WsFrame::parse_all(&wire).unwrap();
        assert_eq!(frames.len(), 1);

        let first = &frames[0];
        assert_eq!(first.opcode, OpCode::Text);
        assert!(first.fin);
        assert_eq!(first.payload, b"Hello");
    }

    /// A close frame round-trips and starts with the big-endian status code.
    #[test]
    fn test_close_frame() {
        let wire = WsFrame::new_close(1000, "Normal").encode();
        let frames = WsFrame::parse_all(&wire).unwrap();
        assert_eq!(frames.len(), 1);

        let close = &frames[0];
        assert_eq!(close.opcode, OpCode::Close);
        // 1000 == 0x03E8
        assert_eq!(&close.payload[..2], &[0x03, 0xE8]);
    }
}