1use bytes::{BufMut, Bytes, BytesMut};
2use flate2::{Decompress, FlushDecompress};
3use std::time::{SystemTime, UNIX_EPOCH};
4
/// Opcode carried in the low four bits of a frame's first byte (RFC 6455 §5.2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsOpcode {
    /// `0x0` — continues a fragmented message.
    Continuation,
    /// `0x1` — text data.
    Text,
    /// `0x2` — binary data.
    Binary,
    /// `0x8` — connection close.
    Close,
    /// `0x9` — ping.
    Ping,
    /// `0xA` — pong.
    Pong,
    /// Any other opcode; the parser stores the raw 4-bit value here.
    Other(u8),
}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct WsFrame {
18 pub fin: bool,
19 pub rsv1: bool,
20 pub rsv2: bool,
21 pub rsv3: bool,
22 pub opcode: WsOpcode,
23 pub payload: Bytes,
24}
25
26impl WsFrame {
27 pub fn is_text(&self) -> bool {
28 self.opcode == WsOpcode::Text
29 }
30
31 pub fn is_binary(&self) -> bool {
32 self.opcode == WsOpcode::Binary
33 }
34
35 pub fn is_continuation(&self) -> bool {
36 self.opcode == WsOpcode::Continuation
37 }
38
39 pub fn text(&self) -> Option<&str> {
40 if self.opcode == WsOpcode::Text {
41 std::str::from_utf8(&self.payload).ok()
42 } else {
43 None
44 }
45 }
46
47 pub fn set_text(&mut self, data: &str) {
48 self.opcode = WsOpcode::Text;
49 self.payload = Bytes::copy_from_slice(data.as_bytes());
50 self.rsv1 = false;
51 self.rsv2 = false;
52 self.rsv3 = false;
53 }
54
55 pub fn set_binary(&mut self, data: impl Into<Bytes>) {
56 self.opcode = WsOpcode::Binary;
57 self.payload = data.into();
58 self.rsv1 = false;
59 self.rsv2 = false;
60 self.rsv3 = false;
61 }
62
63 pub fn decompress_with(&self, decompressor: &mut Decompress) -> Option<Bytes> {
64 if !self.rsv1 {
65 return Some(self.payload.clone());
66 }
67
68 let mut data = self.payload.to_vec();
69 if self.fin {
72 data.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]);
73 }
74
75 let mut out = Vec::with_capacity(self.payload.len() * 3);
76 match decompressor.decompress_vec(&data, &mut out, FlushDecompress::Sync) {
77 Ok(_) => Some(Bytes::from(out)),
78 Err(e) => {
79 eprintln!("Decompression error: {:?}", e);
80 None
81 }
82 }
83 }
84}
85
/// Outcome of a call to [`parse_ws_frames`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsParseResult {
    /// Parsing stopped and the output vector holds at least one frame.
    Ok,
    /// More bytes are required, and the output vector is still empty.
    Incomplete,
    /// The buffered bytes cannot be parsed as valid WebSocket frames.
    Invalid,
}
92
93pub fn parse_ws_frames(buffer: &mut BytesMut, frames: &mut Vec<WsFrame>) -> WsParseResult {
94 if frames.capacity() == 0 {
96 frames.reserve(16);
97 }
98
99 loop {
100 let buffer_len = buffer.len();
101 if buffer_len < 2 {
102 return if frames.is_empty() {
103 WsParseResult::Incomplete
104 } else {
105 WsParseResult::Ok
106 };
107 }
108
109 let b0 = buffer[0];
110 let b1 = buffer[1];
111 let fin = (b0 & 0x80) != 0;
112 let rsv1 = (b0 & 0x40) != 0;
113 let rsv2 = (b0 & 0x20) != 0;
114 let rsv3 = (b0 & 0x10) != 0;
115 let opcode = match b0 & 0x0f {
116 0x0 => WsOpcode::Continuation,
117 0x1 => WsOpcode::Text,
118 0x2 => WsOpcode::Binary,
119 0x8 => WsOpcode::Close,
120 0x9 => WsOpcode::Ping,
121 0xA => WsOpcode::Pong,
122 v => WsOpcode::Other(v),
123 };
124 let masked = (b1 & 0x80) != 0;
125 let mut len = (b1 & 0x7f) as u64;
126 let mut offset = 2usize;
127
128 if len == 126 {
129 if buffer_len < 4 {
130 return if frames.is_empty() {
131 WsParseResult::Incomplete
132 } else {
133 WsParseResult::Ok
134 };
135 }
136 len = u16::from_be_bytes([buffer[2], buffer[3]]) as u64;
137 offset = 4;
138 } else if len == 127 {
139 if buffer_len < 10 {
140 return if frames.is_empty() {
141 WsParseResult::Incomplete
142 } else {
143 WsParseResult::Ok
144 };
145 }
146 len = u64::from_be_bytes([
147 buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], buffer[8],
148 buffer[9],
149 ]);
150 offset = 10;
151 }
152
153 let len_usize = match usize::try_from(len) {
154 Ok(v) => v,
155 Err(_) => return WsParseResult::Invalid,
156 };
157
158 let mask_key = if masked {
159 if buffer_len < offset + 4 {
160 return if frames.is_empty() {
161 WsParseResult::Incomplete
162 } else {
163 WsParseResult::Ok
164 };
165 }
166 let key = [
167 buffer[offset],
168 buffer[offset + 1],
169 buffer[offset + 2],
170 buffer[offset + 3],
171 ];
172 offset += 4;
173 Some(key)
174 } else {
175 None
176 };
177
178 if buffer_len < offset + len_usize {
179 return if frames.is_empty() {
180 WsParseResult::Incomplete
181 } else {
182 WsParseResult::Ok
183 };
184 }
185
186 let mut data = buffer.split_to(offset + len_usize);
187 let _header = data.split_to(offset);
188
189 let payload = if let Some(key) = mask_key {
190 let mut payload_vec = data.to_vec();
191 apply_mask(&mut payload_vec, key);
192 Bytes::from(payload_vec)
193 } else {
194 data.freeze()
195 };
196
197 if opcode == WsOpcode::Continuation {
198 if let Some(last_frame) = frames.last_mut() {
199 let mut new_payload =
200 BytesMut::with_capacity(last_frame.payload.len() + payload.len());
201 new_payload.extend_from_slice(&last_frame.payload);
202 new_payload.extend_from_slice(&payload);
203 last_frame.payload = new_payload.freeze();
204 last_frame.fin = fin;
205 continue;
206 } else {
207 return WsParseResult::Invalid;
208 }
209 }
210
211 frames.push(WsFrame {
212 fin,
213 rsv1,
214 rsv2,
215 rsv3,
216 opcode,
217 payload,
218 });
219 }
220}
221
222pub fn encode_ws_frame(frame: &WsFrame, mask_key: Option<[u8; 4]>) -> Vec<u8> {
223 let mut out = Vec::new();
224 encode_ws_frame_into(frame, mask_key, &mut out);
225 out
226}
227
228pub fn encode_ws_frame_into(frame: &WsFrame, mask_key: Option<[u8; 4]>, out: &mut impl BufMut) {
229 let opcode = match frame.opcode {
230 WsOpcode::Continuation => 0x0,
231 WsOpcode::Text => 0x1,
232 WsOpcode::Binary => 0x2,
233 WsOpcode::Close => 0x8,
234 WsOpcode::Ping => 0x9,
235 WsOpcode::Pong => 0xA,
236 WsOpcode::Other(v) => v & 0x0f,
237 };
238 let mut b0 = if frame.fin { 0x80 } else { 0x00 } | opcode;
239 if frame.rsv1 {
240 b0 |= 0x40;
241 }
242 if frame.rsv2 {
243 b0 |= 0x20;
244 }
245 if frame.rsv3 {
246 b0 |= 0x10;
247 }
248 out.put_u8(b0);
249
250 let masked = mask_key.is_some();
251 let payload_len = frame.payload.len() as u64;
252 if payload_len <= 125 {
253 out.put_u8((if masked { 0x80 } else { 0x00 }) | payload_len as u8);
254 } else if payload_len <= u16::MAX as u64 {
255 out.put_u8(if masked { 0x80 | 126 } else { 126 });
256 out.put_slice(&(payload_len as u16).to_be_bytes());
257 } else {
258 out.put_u8(if masked { 0x80 | 127 } else { 127 });
259 out.put_slice(&payload_len.to_be_bytes());
260 }
261
262 if let Some(key) = mask_key {
263 out.put_slice(&key);
264 let mut masked_payload = frame.payload.to_vec();
265 apply_mask(&mut masked_payload, key);
266 out.put_slice(&masked_payload);
267 } else {
268 out.put_slice(&frame.payload);
269 }
270}
271
/// Returns a fresh 4-byte masking key for client-to-server frames.
///
/// RFC 6455 §5.3 requires masking keys to be unpredictable. The raw wall-clock time is not:
/// an observer can estimate it closely, and on a coarse clock back-to-back calls yield the same
/// key. The timestamp is therefore hashed with a `RandomState`, which the standard library
/// initialises with random keys. The low 32 bits of the digest become the mask.
///
/// NOTE(review): this is not a cryptographically secure RNG. Prefer an OS/CSPRNG source if
/// one becomes available to this crate.
pub fn mask_key_from_time() -> [u8; 4] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);

    // Every `RandomState::new()` yields differently keyed hashers, so repeated calls diverge
    // even when the clock has not ticked.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let digest = hasher.finish().to_le_bytes();
    [digest[0], digest[1], digest[2], digest[3]]
}
284
/// XORs `payload` in place with the 4-byte masking `key`, repeating the key as needed
/// (RFC 6455 §5.3).
///
/// The operation is its own inverse: applying it twice with the same key restores the input.
fn apply_mask(payload: &mut [u8], key: [u8; 4]) {
    payload
        .iter_mut()
        .zip(key.iter().cycle())
        .for_each(|(byte, k)| *byte ^= *k);
}
290
/// Direction in which a frame travels between the downstream and upstream peers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WebsocketDirection {
    /// From the downstream side towards the upstream side.
    DownstreamToUpstream,
    /// From the upstream side towards the downstream side.
    UpstreamToDownstream,
}
296
/// Errors produced while handling WebSocket traffic.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be boxed or used with
/// `?` alongside other error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketError {
    /// The bytes could not be parsed as a valid WebSocket frame.
    InvalidFrame,
}

impl std::fmt::Display for WebsocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WebsocketError::InvalidFrame => f.write_str("invalid websocket frame"),
        }
    }
}

impl std::error::Error for WebsocketError {}
301
/// How to react after a WebSocket error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebsocketErrorAction {
    /// Let the traffic through unchanged (presumed meaning; confirm with callers).
    PassThrough,
    /// Discard the offending data.
    Drop,
    /// Close the connection. NOTE(review): the bytes are presumably the close-frame payload;
    /// confirm with callers.
    Close(Option<Vec<u8>>),
}
308
309#[derive(Debug, Clone, PartialEq, Eq)]
310pub enum WebsocketMessageAction {
311 Forward(WsFrame),
312 Drop,
313 Close(Option<Vec<u8>>),
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame with all RSV bits clear.
    fn frame(opcode: WsOpcode, fin: bool, payload: Bytes) -> WsFrame {
        WsFrame {
            fin,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode,
            payload,
        }
    }

    #[test]
    fn parse_and_encode_text_frame() {
        let frame = WsFrame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: WsOpcode::Text,
            payload: Bytes::from_static(b"hello"),
        };
        let mask_key = [1, 2, 3, 4];
        let encoded = encode_ws_frame(&frame, Some(mask_key));
        let mut buffer = BytesMut::from(encoded.as_slice());
        let mut frames = Vec::new();
        let result = parse_ws_frames(&mut buffer, &mut frames);
        assert_eq!(result, WsParseResult::Ok);
        assert!(buffer.is_empty());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0], frame);
    }

    #[test]
    fn parse_incomplete_frame() {
        let mut buffer = BytesMut::from(&[0x81, 0x85, 1, 2, 3][..]);
        let mut frames = Vec::new();
        let result = parse_ws_frames(&mut buffer, &mut frames);
        assert_eq!(result, WsParseResult::Incomplete);
        assert!(frames.is_empty());
    }

    #[test]
    fn test_ws_frame_helpers() {
        let mut frame = WsFrame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: WsOpcode::Binary,
            payload: Bytes::from_static(b"binary"),
        };

        assert!(frame.is_binary());
        assert!(!frame.is_text());
        assert_eq!(frame.text(), None);

        frame.set_text("hello");
        assert!(frame.is_text());
        assert!(!frame.is_binary());
        assert_eq!(frame.text(), Some("hello"));
        assert_eq!(frame.payload, Bytes::from_static(b"hello"));

        frame.set_binary(vec![1, 2, 3]);
        assert!(frame.is_binary());
        assert_eq!(frame.payload, Bytes::from_static(b"\x01\x02\x03"));
    }

    #[test]
    fn extended_16bit_length_round_trip() {
        let original = frame(WsOpcode::Binary, true, Bytes::from(vec![0xAB; 300]));
        let encoded = encode_ws_frame(&original, None);
        // 300 == 0x012C, carried in the 16-bit extended length field.
        assert_eq!(&encoded[..4], &[0x82, 126, 0x01, 0x2C][..]);
        assert_eq!(encoded.len(), 4 + 300);

        let mut buffer = BytesMut::from(encoded.as_slice());
        let mut frames = Vec::new();
        assert_eq!(parse_ws_frames(&mut buffer, &mut frames), WsParseResult::Ok);
        assert!(buffer.is_empty());
        assert_eq!(frames, vec![original]);
    }

    #[test]
    fn extended_64bit_length_round_trip() {
        let original = frame(WsOpcode::Binary, true, Bytes::from(vec![7u8; 70_000]));
        let encoded = encode_ws_frame(&original, Some([1, 2, 3, 4]));
        // 70_000 == 0x1_1170, carried in the 64-bit extended length field with the mask bit.
        assert_eq!(
            &encoded[..10],
            &[0x82, 0x80 | 127, 0, 0, 0, 0, 0, 0x01, 0x11, 0x70][..]
        );

        let mut buffer = BytesMut::from(encoded.as_slice());
        let mut frames = Vec::new();
        assert_eq!(parse_ws_frames(&mut buffer, &mut frames), WsParseResult::Ok);
        assert_eq!(frames, vec![original]);
    }

    #[test]
    fn parse_reassembles_fragmented_message() {
        let first = frame(WsOpcode::Text, false, Bytes::from_static(b"hel"));
        let rest = frame(WsOpcode::Continuation, true, Bytes::from_static(b"lo"));
        let mut buffer = BytesMut::new();
        buffer.extend_from_slice(&encode_ws_frame(&first, Some([5, 6, 7, 8])));
        buffer.extend_from_slice(&encode_ws_frame(&rest, Some([9, 10, 11, 12])));

        let mut frames = Vec::new();
        assert_eq!(parse_ws_frames(&mut buffer, &mut frames), WsParseResult::Ok);
        assert!(buffer.is_empty());
        assert_eq!(frames.len(), 1);
        assert!(frames[0].fin);
        assert_eq!(frames[0].text(), Some("hello"));
    }

    #[test]
    fn parse_rejects_orphan_continuation() {
        // FIN + continuation opcode, zero-length payload, with nothing to continue.
        let mut buffer = BytesMut::from(&[0x80, 0x00][..]);
        let mut frames = Vec::new();
        assert_eq!(parse_ws_frames(&mut buffer, &mut frames), WsParseResult::Invalid);
        assert!(frames.is_empty());
    }

    #[test]
    fn parse_keeps_partial_trailing_frame() {
        let mut buffer = BytesMut::new();
        let one = frame(WsOpcode::Text, true, Bytes::from_static(b"one"));
        let two = frame(WsOpcode::Binary, true, Bytes::from_static(b"two"));
        buffer.extend_from_slice(&encode_ws_frame(&one, None));
        buffer.extend_from_slice(&encode_ws_frame(&two, None));
        // Header announcing 5 payload bytes, followed by only 1 of them.
        buffer.extend_from_slice(&[0x81, 0x05, b'p']);

        let mut frames = Vec::new();
        assert_eq!(parse_ws_frames(&mut buffer, &mut frames), WsParseResult::Ok);
        assert_eq!(frames, vec![one, two]);
        assert_eq!(&buffer[..], &[0x81, 0x05, b'p'][..]);
    }

    #[test]
    fn apply_mask_is_an_involution() {
        let mut data = *b"websocket";
        apply_mask(&mut data, [0xde, 0xad, 0xbe, 0xef]);
        assert_ne!(&data, b"websocket");
        apply_mask(&mut data, [0xde, 0xad, 0xbe, 0xef]);
        assert_eq!(&data, b"websocket");
    }
}