pipe_chain/parser/
websocket.rs

1//! websocket parser
2use crate::{
3    byte::{be_u16, be_u64, const_take, take},
4    AndThenExt, Incomplete, MapExt, Pipe, Result as PResult,
5};
6use fatal_error::FatalError;
7use std::ops::Deref;
8
/// Websocket frame opcode
///
/// The opcode occupies the low four bits of the first header byte. Each of the sixteen
/// possible values has its own variant; the declaration order follows the numeric value,
/// which is what the derived `Ord` compares.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpCode {
    /// Continuation frame (`0x0`)
    Continuation,
    /// Text frame (`0x1`)
    Text,
    /// Binary frame (`0x2`)
    Binary,
    /// Reserved non-control opcode `0x3`
    NonControl1,
    /// Reserved non-control opcode `0x4`
    NonControl2,
    /// Reserved non-control opcode `0x5`
    NonControl3,
    /// Reserved non-control opcode `0x6`
    NonControl4,
    /// Reserved non-control opcode `0x7`
    NonControl5,
    /// Close frame (`0x8`)
    Close,
    /// Ping frame (`0x9`)
    Ping,
    /// Pong frame (`0xA`)
    Pong,
    /// Reserved control opcode `0xB`
    Control1,
    /// Reserved control opcode `0xC`
    Control2,
    /// Reserved control opcode `0xD`
    Control3,
    /// Reserved control opcode `0xE`
    Control4,
    /// Reserved control opcode `0xF`
    Control5,
    /// Future proof not in the rfc: any byte greater than `0xF`
    Other(u8),
}

/// Unknown opcode
///
/// Carries the raw byte rejected by [`OpCode::validate`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvalidOpCode(u8);

impl std::fmt::Display for InvalidOpCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "InvalidOpCode: {}", self.0)
    }
}

impl std::error::Error for InvalidOpCode {}

impl OpCode {
    /// validates this opcode
    ///
    /// Returns the opcode unchanged unless it is [`OpCode::Other`], in which case the raw
    /// byte is reported through [`InvalidOpCode`]. The reserved `NonControl*` / `Control*`
    /// variants are accepted.
    pub fn validate(self) -> Result<OpCode, InvalidOpCode> {
        match self {
            OpCode::Other(x) => Err(InvalidOpCode(x)),
            x => Ok(x),
        }
    }
}

impl From<u8> for OpCode {
    /// Maps `0..=15` to the matching named variant and anything larger to `Other`.
    fn from(x: u8) -> Self {
        match x {
            0 => OpCode::Continuation,
            1 => OpCode::Text,
            2 => OpCode::Binary,
            3 => OpCode::NonControl1,
            4 => OpCode::NonControl2,
            5 => OpCode::NonControl3,
            6 => OpCode::NonControl4,
            7 => OpCode::NonControl5,
            8 => OpCode::Close,
            9 => OpCode::Ping,
            10 => OpCode::Pong,
            11 => OpCode::Control1,
            12 => OpCode::Control2,
            13 => OpCode::Control3,
            14 => OpCode::Control4,
            15 => OpCode::Control5,
            x => OpCode::Other(x),
        }
    }
}

impl From<OpCode> for u8 {
    /// Numeric value of the opcode (inverse of `From<u8>`).
    // The variant -> byte table lives in the `Deref` impl only, so both conversions agree.
    fn from(x: OpCode) -> Self { *x }
}

impl std::ops::Deref for OpCode {
    type Target = u8;

    /// Borrows the numeric value of the opcode.
    fn deref(&self) -> &Self::Target {
        match self {
            OpCode::Continuation => &0,
            OpCode::Text => &1,
            OpCode::Binary => &2,
            OpCode::NonControl1 => &3,
            OpCode::NonControl2 => &4,
            OpCode::NonControl3 => &5,
            OpCode::NonControl4 => &6,
            OpCode::NonControl5 => &7,
            OpCode::Close => &8,
            OpCode::Ping => &9,
            OpCode::Pong => &10,
            OpCode::Control1 => &11,
            OpCode::Control2 => &12,
            OpCode::Control3 => &13,
            OpCode::Control4 => &14,
            OpCode::Control5 => &15,
            OpCode::Other(x) => x,
        }
    }
}
135
/// Packet Size
///
/// Keeps the payload length together with the width it is encoded with, so that a frame
/// can be written back with the same length encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Size {
    /// u8 packet size
    U8(u8),
    /// u16 packet size
    U16(u16),
    /// u64 packet size
    U64(u64),
}

impl Size {
    /// First byte of the packet size
    ///
    /// The length itself for `U8`, otherwise the marker announcing an extended length:
    /// `126` for `U16` and `127` for `U64`.
    pub fn first_byte(&self) -> u8 {
        match self {
            Size::U8(x) => *x,
            Size::U16(_) => 126,
            Size::U64(_) => 127,
        }
    }

    /// remaining byte of the packet size
    ///
    /// Big-endian bytes of the extended length: empty for `U8`, 2 bytes for `U16`,
    /// 8 bytes for `U64`.
    pub fn final_size(self) -> Vec<u8> {
        match self {
            Size::U8(_) => vec![],
            Size::U16(x) => x.to_be_bytes().to_vec(),
            Size::U64(x) => x.to_be_bytes().to_vec(),
        }
    }
}

impl From<Size> for usize {
    /// Payload length as a `usize`.
    ///
    /// A `U64` length that does not fit in `usize` (possible on targets where `usize` is
    /// narrower than 64 bits) saturates to `usize::MAX` instead of being truncated to an
    /// unrelated smaller length.
    fn from(x: Size) -> Self {
        match x {
            Size::U8(v) => usize::from(v),
            Size::U16(v) => usize::from(v),
            // fully qualified so the call does not depend on `TryFrom` being in the prelude
            Size::U64(v) => {
                <usize as std::convert::TryFrom<u64>>::try_from(v).unwrap_or(usize::MAX)
            }
        }
    }
}
176
177/// Websocket frame
178#[derive(Debug, Clone)]
179pub struct Frame {
180    fin:    bool,
181    rsv1:   bool,
182    rsv2:   bool,
183    rsv3:   bool,
184    opcode: OpCode,
185    mask:   Option<[u8; 4]>,
186    size:   Size,
187    data:   Vec<u8>,
188}
189
190impl Frame {
191    /// apply the mask on frame data
192    pub fn mask(mut self) -> Frame {
193        if let Some(mask) = &self.mask {
194            for (i, v) in self.data.iter_mut().enumerate() {
195                *v ^= mask[i % 4];
196            }
197        }
198        self
199    }
200
201    /// Transforms this frame into a bytes
202    pub fn into_vec(self) -> Vec<u8> {
203        let b1 = ((self.fin as u8) << 7)
204            | ((self.rsv1 as u8) << 6)
205            | ((self.rsv2 as u8) << 5)
206            | ((self.rsv3 as u8) << 4)
207            | (u8::from(self.opcode) & 0x0F);
208        let b2 = ((self.mask.is_some() as u8) << 7) | (self.size.first_byte() & 0x7F);
209        let mut r = vec![b1, b2];
210        r.extend(self.size.final_size());
211        if let Some(ref mask) = self.mask {
212            r.extend(mask.iter());
213        }
214        r.extend(&self.data);
215        r
216    }
217}
218
219impl From<Frame> for Vec<u8> {
220    fn from(frame: Frame) -> Self { frame.into_vec() }
221}
222
223#[derive(Debug, Clone)]
224enum FrameState {
225    Masked(Frame),
226    UnMasked(Frame),
227}
228
229impl From<FrameState> for Frame {
230    fn from(x: FrameState) -> Self {
231        match x {
232            FrameState::Masked(x) => x,
233            FrameState::UnMasked(x) => x,
234        }
235    }
236}
237
238impl Deref for FrameState {
239    type Target = Frame;
240
241    fn deref(&self) -> &Self::Target {
242        match self {
243            FrameState::Masked(x) | FrameState::UnMasked(x) => x,
244        }
245    }
246}
247
248impl FrameState {
249    fn unmask(self) -> FrameState {
250        match self {
251            FrameState::Masked(frame) => FrameState::UnMasked(frame.mask()),
252            x @ FrameState::UnMasked(_) => x,
253        }
254    }
255
256    fn mask(self) -> FrameState {
257        match self {
258            FrameState::UnMasked(frame) => FrameState::Masked(frame.mask()),
259            x @ FrameState::Masked(_) => x,
260        }
261    }
262
263    pub fn into_frame(self) -> Frame {
264        match self {
265            FrameState::Masked(x) | FrameState::UnMasked(x) => x,
266        }
267    }
268}
269
270///
271#[derive(Debug, Clone)]
272pub struct MaskedFrame(FrameState);
273
274impl MaskedFrame {
275    /// mask this frame
276    pub fn mask(self) -> MaskedFrame { MaskedFrame(self.0.mask()) }
277
278    /// unmask this frame
279    pub fn unmask(self) -> MaskedFrame { MaskedFrame(self.0.unmask()) }
280
281    /// returns this frame inner type
282    pub fn into_frame(self) -> Frame { self.0.into_frame() }
283}
284
285impl Deref for MaskedFrame {
286    type Target = Frame;
287
288    fn deref(&self) -> &Self::Target { &self.0 }
289}
290
291impl From<MaskedFrame> for Frame {
292    fn from(x: MaskedFrame) -> Self { x.0.into() }
293}
294
/// Frame has an invalid size
///
/// Wraps the offending length byte.
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
pub struct InvalidFrameSize(u8);

impl std::fmt::Display for InvalidFrameSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let InvalidFrameSize(raw) = self;
        write!(f, "InvalidFrameSize: {}", raw)
    }
}

impl std::error::Error for InvalidFrameSize {}
306
307/// Error during frame size parsing
308#[derive(Debug, Clone, PartialEq, Eq)]
309pub enum FrameSizeError {
310    /// Frame size needs more input bytes
311    Incomplete(Incomplete),
312    /// Frame size is incorrect
313    InvalidSize(InvalidFrameSize),
314}
315
316impl From<Incomplete> for FrameSizeError {
317    fn from(value: Incomplete) -> Self { FrameSizeError::Incomplete(value) }
318}
319
320impl From<InvalidFrameSize> for FrameSizeError {
321    fn from(value: InvalidFrameSize) -> Self { FrameSizeError::InvalidSize(value) }
322}
323
324impl std::fmt::Display for FrameSizeError {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        match self {
327            FrameSizeError::Incomplete(x) => write!(f, "FrameSizeError: {x}"),
328            FrameSizeError::InvalidSize(x) => write!(f, "FrameSizeError: {x}"),
329        }
330    }
331}
332
333impl std::error::Error for FrameSizeError {
334    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
335        match self {
336            FrameSizeError::Incomplete(x) => Some(x),
337            FrameSizeError::InvalidSize(x) => Some(x),
338        }
339    }
340}
341
342fn parse_frame_size(buf: &[u8], head: u16) -> PResult<&[u8], (Size,), FrameSizeError> {
343    match (head as u8) & 0x7F {
344        x @ 0..=125 => Ok((buf, (Size::U8(x),))),
345        126 => be_u16().map1(Size::U16).apply(buf),
346        127 => be_u64().map1(Size::U64).apply(buf),
347        x => Err(FatalError::Error(InvalidFrameSize(x).into())),
348    }
349}
350
351/// returns a websocket parser
352pub fn frame<'a>() -> impl Pipe<&'a [u8], (Frame,), FrameSizeError> {
353    move |x: &'a [u8]| {
354        let (buf, (head,)) = be_u16().apply(x)?;
355        let (buf, (size, mask)) = { move |x| parse_frame_size(x, head) }
356            .ok_and_then(|i, (o,)| {
357                if head & 0x80 == 0x80 {
358                    Ok(const_take::<4, _>().map(|x: [u8; 4]| (o, Some(x))).apply(i)?)
359                } else {
360                    Ok((i, (o, None)))
361                }
362            })
363            .apply(buf)?;
364        let (buf, (data,)) = take(size.into()).apply(buf)?;
365        Ok((
366            buf,
367            (Frame {
368                fin: (head >> 8) & 0x80 == 0x80,
369                rsv1: (head >> 8) & 0x40 == 0x40,
370                rsv2: (head >> 8) & 0x20 == 0x20,
371                rsv3: (head >> 8) & 0x10 == 0x10,
372                opcode: (((head >> 8) as u8) & 0x0F).into(),
373                size,
374                mask,
375                data: data.to_vec(),
376            },),
377        ))
378    }
379}
380
381/// returns a parser of masked frames
382pub fn masked_frame<'a>() -> impl Pipe<&'a [u8], (MaskedFrame,), FrameSizeError> {
383    frame().map1(|x: Frame| MaskedFrame(FrameState::Masked(x)))
384}
385
386/// returns a parser of unmasked frames
387pub fn unmasked_frame<'a>() -> impl Pipe<&'a [u8], (MaskedFrame,), FrameSizeError> {
388    frame().map1(|x| MaskedFrame(FrameState::UnMasked(x)))
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Pipe, UnpackExt};

    #[test]
    fn rfc_tests() {
        // TODO: complete with: https://datatracker.ietf.org/doc/html/rfc6455#section-5.7

        // single unmasked text frame, parsed then serialized back
        let (rest, (hello,)) =
            unmasked_frame().apply(&[0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]).unwrap();
        assert!(rest.is_empty());
        assert_eq!(hello.data, b"Hello");
        assert!(hello.fin);
        assert!(!hello.rsv1);
        assert!(!hello.rsv2);
        assert!(!hello.rsv3);
        assert_eq!(hello.mask, None);
        assert_eq!(hello.opcode, OpCode::Text);
        assert_eq!(hello.size, Size::U8(5));
        let bytes: Vec<u8> = hello.into_frame().into();
        assert_eq!(&bytes, &[0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]);

        // single masked text frame
        let (rest, (masked,)) = masked_frame()
            .apply(&[0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58])
            .unwrap();
        let clear = masked.unmask();
        assert!(rest.is_empty());
        assert_eq!(clear.data, b"Hello");

        // fragmented unmasked text: first fragment
        let (rest, (first,)) = unmasked_frame().apply(&[0x01, 0x03, 0x48, 0x65, 0x6c]).unwrap();
        assert_eq!(rest, b"");
        assert_eq!(first.data, b"Hel");
        assert!(!first.fin);

        // fragmented unmasked text: final continuation fragment
        let (rest, (last,)) = unmasked_frame().apply(&[0x80, 0x02, 0x6c, 0x6f]).unwrap();
        assert_eq!(rest, b"");
        assert_eq!(last.data, b"lo");
        assert!(last.fin);

        // unmasked ping
        let (rest, (ping,)) =
            unmasked_frame().apply(&[0x89, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]).unwrap();
        assert_eq!(rest, b"");
        assert_eq!(ping.data, b"Hello");
        assert!(ping.fin);
        assert_eq!(ping.opcode, OpCode::Ping);

        // masked pong, unmasked inside the pipe
        let (rest, pong) = masked_frame()
            .map1(MaskedFrame::unmask)
            .unpack()
            .apply(&[0x8a, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58])
            .unwrap();
        assert_eq!(rest, b"");
        assert_eq!(pong.data, b"Hello");
        assert!(pong.fin);
        assert_eq!(pong.opcode, OpCode::Pong);

        // 256 bytes of binary payload: 16-bit extended length (4 header bytes)
        let mut medium = [0u8; 260];
        medium[0] = 0x82;
        medium[1] = 0x7E;
        medium[2] = 0x01;
        let (rest, (parsed,)) = unmasked_frame().apply(&medium).unwrap();
        assert!(rest.is_empty());
        assert_eq!(parsed.size, Size::U16(256));
        let bytes: Vec<u8> = parsed.into_frame().into();
        assert_eq!(&bytes, &medium);

        // 64KiB of binary payload: 64-bit extended length (10 header bytes)
        let mut large = [0u8; 65546];
        large[0] = 0x82;
        large[1] = 0x7F;
        large[7] = 0x01;
        let (rest, (parsed,)) = unmasked_frame().apply(&large).unwrap();
        assert!(rest.is_empty());
        assert_eq!(parsed.size, Size::U64(65536));
        let bytes: Vec<u8> = parsed.into_frame().into();
        assert_eq!(&bytes, &large);
    }
}