Skip to main content

jokoway_core/
websocket.rs

1use bytes::{BufMut, Bytes, BytesMut};
2use flate2::{Decompress, FlushDecompress};
3use std::time::{SystemTime, UNIX_EPOCH};
4
/// WebSocket frame opcode (RFC 6455 §5.2).
///
/// Opcode values without a dedicated variant are carried verbatim in
/// [`WsOpcode::Other`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsOpcode {
    /// `0x0`: continues a fragmented message.
    Continuation,
    /// `0x1`: text data.
    Text,
    /// `0x2`: binary data.
    Binary,
    /// `0x8`: connection close.
    Close,
    /// `0x9`: ping.
    Ping,
    /// `0xA`: pong.
    Pong,
    /// Raw value of any other opcode; only its low four bits are emitted when encoding.
    Other(u8),
}
15
16#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct WsFrame {
18    pub fin: bool,
19    pub rsv1: bool,
20    pub rsv2: bool,
21    pub rsv3: bool,
22    pub opcode: WsOpcode,
23    pub payload: Bytes,
24}
25
26impl WsFrame {
27    pub fn is_text(&self) -> bool {
28        self.opcode == WsOpcode::Text
29    }
30
31    pub fn is_binary(&self) -> bool {
32        self.opcode == WsOpcode::Binary
33    }
34
35    pub fn is_continuation(&self) -> bool {
36        self.opcode == WsOpcode::Continuation
37    }
38
39    pub fn text(&self) -> Option<&str> {
40        if self.opcode == WsOpcode::Text {
41            std::str::from_utf8(&self.payload).ok()
42        } else {
43            None
44        }
45    }
46
47    pub fn set_text(&mut self, data: &str) {
48        self.opcode = WsOpcode::Text;
49        self.payload = Bytes::copy_from_slice(data.as_bytes());
50        self.rsv1 = false;
51        self.rsv2 = false;
52        self.rsv3 = false;
53    }
54
55    pub fn set_binary(&mut self, data: impl Into<Bytes>) {
56        self.opcode = WsOpcode::Binary;
57        self.payload = data.into();
58        self.rsv1 = false;
59        self.rsv2 = false;
60        self.rsv3 = false;
61    }
62
63    pub fn decompress_with(&self, decompressor: &mut Decompress) -> Option<Bytes> {
64        if !self.rsv1 {
65            return Some(self.payload.clone());
66        }
67
68        let mut data = self.payload.to_vec();
69        // Append 0x00 0x00 0xff 0xff tail if fin is true (end of message)
70        // permessage-deflate strips this tail from the output.
71        if self.fin {
72            data.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]);
73        }
74
75        let mut out = Vec::with_capacity(self.payload.len() * 3);
76        match decompressor.decompress_vec(&data, &mut out, FlushDecompress::Sync) {
77            Ok(_) => Some(Bytes::from(out)),
78            Err(e) => {
79                eprintln!("Decompression error: {:?}", e);
80                None
81            }
82        }
83    }
84}
85
/// Outcome of [`parse_ws_frames`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsParseResult {
    /// The output vector holds at least one frame; no further complete frame is buffered.
    Ok,
    /// No frame is available yet; more input bytes are needed.
    Incomplete,
    /// The input could not be parsed as valid WebSocket framing.
    Invalid,
}
92
93pub fn parse_ws_frames(buffer: &mut BytesMut, frames: &mut Vec<WsFrame>) -> WsParseResult {
94    // Pre-allocate capacity for frames to reduce reallocations
95    if frames.capacity() == 0 {
96        frames.reserve(16);
97    }
98
99    loop {
100        let buffer_len = buffer.len();
101        if buffer_len < 2 {
102            return if frames.is_empty() {
103                WsParseResult::Incomplete
104            } else {
105                WsParseResult::Ok
106            };
107        }
108
109        let b0 = buffer[0];
110        let b1 = buffer[1];
111        let fin = (b0 & 0x80) != 0;
112        let rsv1 = (b0 & 0x40) != 0;
113        let rsv2 = (b0 & 0x20) != 0;
114        let rsv3 = (b0 & 0x10) != 0;
115        let opcode = match b0 & 0x0f {
116            0x0 => WsOpcode::Continuation,
117            0x1 => WsOpcode::Text,
118            0x2 => WsOpcode::Binary,
119            0x8 => WsOpcode::Close,
120            0x9 => WsOpcode::Ping,
121            0xA => WsOpcode::Pong,
122            v => WsOpcode::Other(v),
123        };
124        let masked = (b1 & 0x80) != 0;
125        let mut len = (b1 & 0x7f) as u64;
126        let mut offset = 2usize;
127
128        if len == 126 {
129            if buffer_len < 4 {
130                return if frames.is_empty() {
131                    WsParseResult::Incomplete
132                } else {
133                    WsParseResult::Ok
134                };
135            }
136            len = u16::from_be_bytes([buffer[2], buffer[3]]) as u64;
137            offset = 4;
138        } else if len == 127 {
139            if buffer_len < 10 {
140                return if frames.is_empty() {
141                    WsParseResult::Incomplete
142                } else {
143                    WsParseResult::Ok
144                };
145            }
146            len = u64::from_be_bytes([
147                buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], buffer[8],
148                buffer[9],
149            ]);
150            offset = 10;
151        }
152
153        let len_usize = match usize::try_from(len) {
154            Ok(v) => v,
155            Err(_) => return WsParseResult::Invalid,
156        };
157
158        let mask_key = if masked {
159            if buffer_len < offset + 4 {
160                return if frames.is_empty() {
161                    WsParseResult::Incomplete
162                } else {
163                    WsParseResult::Ok
164                };
165            }
166            let key = [
167                buffer[offset],
168                buffer[offset + 1],
169                buffer[offset + 2],
170                buffer[offset + 3],
171            ];
172            offset += 4;
173            Some(key)
174        } else {
175            None
176        };
177
178        if buffer_len < offset + len_usize {
179            return if frames.is_empty() {
180                WsParseResult::Incomplete
181            } else {
182                WsParseResult::Ok
183            };
184        }
185
186        let mut data = buffer.split_to(offset + len_usize);
187        let _header = data.split_to(offset);
188
189        let payload = if let Some(key) = mask_key {
190            let mut payload_vec = data.to_vec();
191            apply_mask(&mut payload_vec, key);
192            Bytes::from(payload_vec)
193        } else {
194            data.freeze()
195        };
196
197        if opcode == WsOpcode::Continuation {
198            if let Some(last_frame) = frames.last_mut() {
199                let mut new_payload =
200                    BytesMut::with_capacity(last_frame.payload.len() + payload.len());
201                new_payload.extend_from_slice(&last_frame.payload);
202                new_payload.extend_from_slice(&payload);
203                last_frame.payload = new_payload.freeze();
204                last_frame.fin = fin;
205                continue;
206            } else {
207                return WsParseResult::Invalid;
208            }
209        }
210
211        frames.push(WsFrame {
212            fin,
213            rsv1,
214            rsv2,
215            rsv3,
216            opcode,
217            payload,
218        });
219    }
220}
221
222pub fn encode_ws_frame(frame: &WsFrame, mask_key: Option<[u8; 4]>) -> Vec<u8> {
223    let mut out = Vec::new();
224    encode_ws_frame_into(frame, mask_key, &mut out);
225    out
226}
227
228pub fn encode_ws_frame_into(frame: &WsFrame, mask_key: Option<[u8; 4]>, out: &mut impl BufMut) {
229    let opcode = match frame.opcode {
230        WsOpcode::Continuation => 0x0,
231        WsOpcode::Text => 0x1,
232        WsOpcode::Binary => 0x2,
233        WsOpcode::Close => 0x8,
234        WsOpcode::Ping => 0x9,
235        WsOpcode::Pong => 0xA,
236        WsOpcode::Other(v) => v & 0x0f,
237    };
238    let mut b0 = if frame.fin { 0x80 } else { 0x00 } | opcode;
239    if frame.rsv1 {
240        b0 |= 0x40;
241    }
242    if frame.rsv2 {
243        b0 |= 0x20;
244    }
245    if frame.rsv3 {
246        b0 |= 0x10;
247    }
248    out.put_u8(b0);
249
250    let masked = mask_key.is_some();
251    let payload_len = frame.payload.len() as u64;
252    if payload_len <= 125 {
253        out.put_u8((if masked { 0x80 } else { 0x00 }) | payload_len as u8);
254    } else if payload_len <= u16::MAX as u64 {
255        out.put_u8(if masked { 0x80 | 126 } else { 126 });
256        out.put_slice(&(payload_len as u16).to_be_bytes());
257    } else {
258        out.put_u8(if masked { 0x80 | 127 } else { 127 });
259        out.put_slice(&payload_len.to_be_bytes());
260    }
261
262    if let Some(key) = mask_key {
263        out.put_slice(&key);
264        let mut masked_payload = frame.payload.to_vec();
265        apply_mask(&mut masked_payload, key);
266        out.put_slice(&masked_payload);
267    } else {
268        out.put_slice(&frame.payload);
269    }
270}
271
/// Returns a 4-byte masking key for client-to-server frames.
///
/// RFC 6455 §5.3 requires masking keys to be unpredictable to intermediaries.
/// Raw clock bits are predictable, so the current time is fed through a hasher
/// built from a fresh `RandomState`. The standard library keys `RandomState`
/// randomly, and each new instance differs from the previous one. Attackers
/// therefore cannot derive the key from the time alone.
///
/// NOTE(review): this is not a CSPRNG. If a randomness crate becomes available,
/// prefer drawing the key from it directly.
pub fn mask_key_from_time() -> [u8; 4] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    (hasher.finish() as u32).to_le_bytes()
}
284
/// XORs `payload` in place with the repeating 4-byte `key` (RFC 6455 §5.3).
///
/// Applying the same key twice restores the original bytes, so this both masks
/// and unmasks.
fn apply_mask(payload: &mut [u8], key: [u8; 4]) {
    payload
        .iter_mut()
        .zip(key.iter().cycle())
        .for_each(|(byte, k)| *byte ^= *k);
}
290
/// Which way a message travels: from the downstream side to the upstream side,
/// or the reverse.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WebsocketDirection {
    /// Downstream side to upstream side.
    DownstreamToUpstream,
    /// Upstream side to downstream side.
    UpstreamToDownstream,
}
296
/// Errors reported while handling WebSocket traffic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WebsocketError {
    /// A frame could not be parsed or violated the framing rules.
    InvalidFrame,
}
301
/// What to do after a [`WebsocketError`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WebsocketErrorAction {
    /// Let the traffic through unchanged.
    PassThrough,
    /// Discard the offending data.
    Drop,
    /// Close the connection.
    ///
    /// NOTE(review): the bytes are presumably the Close frame payload (status
    /// code + reason); confirm against the code that consumes this action.
    Close(Option<Vec<u8>>),
}
308
309#[derive(Debug, Clone, PartialEq, Eq)]
310pub enum WebsocketMessageAction {
311    Forward(WsFrame),
312    Drop,
313    Close(Option<Vec<u8>>),
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame with all RSV bits cleared.
    fn frame(fin: bool, opcode: WsOpcode, payload: Bytes) -> WsFrame {
        WsFrame {
            fin,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode,
            payload,
        }
    }

    /// Runs `parse_ws_frames` over `wire`. Returns the result, the parsed frames,
    /// and the unconsumed bytes.
    fn parse_all(wire: &[u8]) -> (WsParseResult, Vec<WsFrame>, BytesMut) {
        let mut buffer = BytesMut::from(wire);
        let mut frames = Vec::new();
        let result = parse_ws_frames(&mut buffer, &mut frames);
        (result, frames, buffer)
    }

    #[test]
    fn parse_and_encode_text_frame() {
        let frame = WsFrame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: WsOpcode::Text,
            payload: Bytes::from_static(b"hello"),
        };
        let mask_key = [1, 2, 3, 4];
        let encoded = encode_ws_frame(&frame, Some(mask_key));
        let mut buffer = BytesMut::from(encoded.as_slice());
        let mut frames = Vec::new();
        let result = parse_ws_frames(&mut buffer, &mut frames);
        assert_eq!(result, WsParseResult::Ok);
        assert!(buffer.is_empty());
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0], frame);
    }

    #[test]
    fn parse_incomplete_frame() {
        let mut buffer = BytesMut::from(&[0x81, 0x85, 1, 2, 3][..]);
        let mut frames = Vec::new();
        let result = parse_ws_frames(&mut buffer, &mut frames);
        assert_eq!(result, WsParseResult::Incomplete);
        assert!(frames.is_empty());
    }

    #[test]
    fn test_ws_frame_helpers() {
        let mut frame = WsFrame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: WsOpcode::Binary,
            payload: Bytes::from_static(b"binary"),
        };

        assert!(frame.is_binary());
        assert!(!frame.is_text());
        assert_eq!(frame.text(), None);

        frame.set_text("hello");
        assert!(frame.is_text());
        assert!(!frame.is_binary());
        assert_eq!(frame.text(), Some("hello"));
        assert_eq!(frame.payload, Bytes::from_static(b"hello"));

        frame.set_binary(vec![1, 2, 3]);
        assert!(frame.is_binary());
        assert_eq!(frame.payload, Bytes::from_static(b"\x01\x02\x03"));
    }

    #[test]
    fn roundtrip_extended_payload_lengths() {
        // 200 bytes needs the 16-bit length form, 70_000 bytes the 64-bit form.
        let cases: [(usize, u8); 2] = [(200, 126), (70_000, 127)];
        let keys: [Option<[u8; 4]>; 2] = [None, Some([9, 8, 7, 6])];
        for &(size, len_marker) in cases.iter() {
            for &mask_key in keys.iter() {
                let original = frame(true, WsOpcode::Binary, Bytes::from(vec![0xABu8; size]));
                let encoded = encode_ws_frame(&original, mask_key);
                assert_eq!(encoded[1] & 0x7f, len_marker);
                assert_eq!((encoded[1] & 0x80) != 0, mask_key.is_some());

                let (result, frames, rest) = parse_all(&encoded);
                assert_eq!(result, WsParseResult::Ok);
                assert!(rest.is_empty());
                assert_eq!(frames, vec![original]);
            }
        }
    }

    #[test]
    fn merges_continuation_into_fragmented_message() {
        let mut wire = encode_ws_frame(
            &frame(false, WsOpcode::Text, Bytes::from_static(b"hel")),
            Some([1, 2, 3, 4]),
        );
        wire.extend_from_slice(&encode_ws_frame(
            &frame(true, WsOpcode::Continuation, Bytes::from_static(b"lo")),
            Some([5, 6, 7, 8]),
        ));

        let (result, frames, rest) = parse_all(&wire);
        assert_eq!(result, WsParseResult::Ok);
        assert!(rest.is_empty());
        assert_eq!(
            frames,
            vec![frame(true, WsOpcode::Text, Bytes::from_static(b"hello"))]
        );
    }

    #[test]
    fn continuation_without_message_is_invalid() {
        let wire = encode_ws_frame(
            &frame(true, WsOpcode::Continuation, Bytes::from_static(b"orphan")),
            None,
        );
        let (result, frames, _) = parse_all(&wire);
        assert_eq!(result, WsParseResult::Invalid);
        assert!(frames.is_empty());
    }

    #[test]
    fn partial_trailing_frame_stays_buffered() {
        let mut wire = encode_ws_frame(
            &frame(true, WsOpcode::Text, Bytes::from_static(b"hello")),
            None,
        );
        let second = encode_ws_frame(
            &frame(true, WsOpcode::Text, Bytes::from_static(b"world")),
            None,
        );
        wire.extend_from_slice(&second[..3]);

        let (result, frames, rest) = parse_all(&wire);
        assert_eq!(result, WsParseResult::Ok);
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0].text(), Some("hello"));
        assert_eq!(&rest[..], &second[..3]);
    }

    #[test]
    fn decompresses_permessage_deflate_payload() {
        use flate2::{Compress, Compression, FlushCompress};

        let mut compressor = Compress::new(Compression::default(), false);
        let mut compressed = Vec::with_capacity(64);
        compressor
            .compress_vec(b"hello", &mut compressed, FlushCompress::Sync)
            .unwrap();
        // permessage-deflate senders strip the sync-flush trailer from each message.
        assert!(compressed.ends_with(&[0x00, 0x00, 0xff, 0xff]));
        compressed.truncate(compressed.len() - 4);

        let mut compressed_frame = frame(true, WsOpcode::Text, Bytes::from(compressed));
        compressed_frame.rsv1 = true;
        let mut decompressor = Decompress::new(false);
        let inflated = compressed_frame.decompress_with(&mut decompressor);
        assert_eq!(inflated.as_deref(), Some(&b"hello"[..]));
    }
}