1use crate::{
3 byte::{be_u16, be_u64, const_take, take},
4 AndThenExt, Incomplete, MapExt, Pipe, Result as PResult,
5};
6use fatal_error::FatalError;
7use std::ops::Deref;
8
/// Opcode of a WebSocket frame: the low four bits of the first header byte (RFC 6455 §5.2).
///
/// Each of the sixteen four-bit values has a named variant. [`OpCode::Other`] keeps any byte
/// above `15` so that [`OpCode::validate`] can report it. Variants are declared in numeric
/// order, so the derived `Ord` follows the wire value for the named variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum OpCode {
    /// `0x0`: continuation fragment.
    Continuation,
    /// `0x1`: text frame.
    Text,
    /// `0x2`: binary frame.
    Binary,
    /// `0x3`: reserved non-control opcode.
    NonControl1,
    /// `0x4`: reserved non-control opcode.
    NonControl2,
    /// `0x5`: reserved non-control opcode.
    NonControl3,
    /// `0x6`: reserved non-control opcode.
    NonControl4,
    /// `0x7`: reserved non-control opcode.
    NonControl5,
    /// `0x8`: close frame.
    Close,
    /// `0x9`: ping frame.
    Ping,
    /// `0xA`: pong frame.
    Pong,
    /// `0xB`: reserved control opcode.
    Control1,
    /// `0xC`: reserved control opcode.
    Control2,
    /// `0xD`: reserved control opcode.
    Control3,
    /// `0xE`: reserved control opcode.
    Control4,
    /// `0xF`: reserved control opcode.
    Control5,
    /// Any byte that does not fit in four bits (produced by `From<u8>` for values above `15`).
    Other(u8),
}

/// Error returned by [`OpCode::validate`] for [`OpCode::Other`]; carries the raw byte.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct InvalidOpCode(u8);

impl std::fmt::Display for InvalidOpCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "InvalidOpCode: {}", self.0)
    }
}

// Implementing `Error` lets this type be boxed as `dyn Error` and propagated with `?`,
// matching `InvalidFrameSize`.
impl std::error::Error for InvalidOpCode {}

impl OpCode {
    /// Checks that the opcode is one of the sixteen four-bit values.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidOpCode`] with the raw byte when `self` is [`OpCode::Other`].
    pub fn validate(self) -> Result<OpCode, InvalidOpCode> {
        match self {
            OpCode::Other(x) => Err(InvalidOpCode(x)),
            x => Ok(x),
        }
    }
}

impl From<u8> for OpCode {
    /// Decodes a byte: `0..=15` map to the named variants; anything else becomes `Other`.
    fn from(x: u8) -> Self {
        match x {
            0 => OpCode::Continuation,
            1 => OpCode::Text,
            2 => OpCode::Binary,
            3 => OpCode::NonControl1,
            4 => OpCode::NonControl2,
            5 => OpCode::NonControl3,
            6 => OpCode::NonControl4,
            7 => OpCode::NonControl5,
            8 => OpCode::Close,
            9 => OpCode::Ping,
            10 => OpCode::Pong,
            11 => OpCode::Control1,
            12 => OpCode::Control2,
            13 => OpCode::Control3,
            14 => OpCode::Control4,
            15 => OpCode::Control5,
            x => OpCode::Other(x),
        }
    }
}

impl From<OpCode> for u8 {
    /// Encodes the opcode: the four-bit value for named variants, the raw byte for `Other`.
    fn from(x: OpCode) -> Self {
        // Reuse the `Deref` mapping so the encode table is written once rather than twice.
        *x
    }
}

impl std::ops::Deref for OpCode {
    type Target = u8;

    /// Exposes the opcode's byte value by reference.
    ///
    /// NOTE(review): `Deref` to `u8` on a plain enum is unusual; prefer `u8::from(op)` in new code.
    fn deref(&self) -> &Self::Target {
        // Integer literals behind `&` are promoted to `'static`, so these borrows outlive `self`.
        match self {
            OpCode::Continuation => &0,
            OpCode::Text => &1,
            OpCode::Binary => &2,
            OpCode::NonControl1 => &3,
            OpCode::NonControl2 => &4,
            OpCode::NonControl3 => &5,
            OpCode::NonControl4 => &6,
            OpCode::NonControl5 => &7,
            OpCode::Close => &8,
            OpCode::Ping => &9,
            OpCode::Pong => &10,
            OpCode::Control1 => &11,
            OpCode::Control2 => &12,
            OpCode::Control3 => &13,
            OpCode::Control4 => &14,
            OpCode::Control5 => &15,
            OpCode::Other(x) => x,
        }
    }
}
135
/// Payload length of a frame, in the encoding the header uses for it.
///
/// `U8` is the inline 7-bit length. `U16` uses the 2-byte extension (indicator `126`).
/// `U64` uses the 8-byte extension (indicator `127`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Size {
    /// Length stored directly in the 7-bit field (only `0..=125` is a valid encoding).
    U8(u8),
    /// Length stored in a 2-byte big-endian extension.
    U16(u16),
    /// Length stored in an 8-byte big-endian extension.
    U64(u64),
}

impl Size {
    /// Value written into the 7-bit length field of the header.
    ///
    /// For `U8` this is the length itself, returned as-is. For `U16` and `U64` it is the
    /// indicator `126` or `127`.
    pub fn first_byte(&self) -> u8 {
        match self {
            Size::U8(x) => *x,
            Size::U16(_) => 126,
            Size::U64(_) => 127,
        }
    }

    /// Big-endian extended-length bytes that follow the 7-bit field.
    ///
    /// These are empty for `U8`, 2 bytes for `U16` and 8 bytes for `U64`.
    pub fn final_size(self) -> Vec<u8> {
        match self {
            Size::U8(_) => vec![],
            Size::U16(x) => x.to_be_bytes().to_vec(),
            Size::U64(x) => x.to_be_bytes().to_vec(),
        }
    }
}

impl From<Size> for usize {
    /// Converts the length to a byte count.
    ///
    /// A `U64` length that does not fit in `usize` (only possible on targets narrower than
    /// 64 bits) saturates to `usize::MAX` instead of wrapping.
    fn from(x: Size) -> Self {
        match x {
            Size::U8(v) => usize::from(v),
            Size::U16(v) => usize::from(v),
            // A plain `as` cast would truncate on 32-bit targets, so a huge declared length
            // could be read as a small one and the wrong bytes taken as payload. Saturating
            // yields a length no buffer can satisfy, as on 64-bit targets.
            Size::U64(v) => v.min(usize::MAX as u64) as usize,
        }
    }
}
176
/// One WebSocket frame: header flags, opcode, optional masking key, length and payload.
///
/// Fields are private, so frames are built by [`frame`] in this module and serialized with
/// [`Frame::into_vec`]. Whether `data` is currently masked is not recorded here
/// (see [`MaskedFrame`]).
#[derive(Debug, Clone)]
pub struct Frame {
    /// FIN bit (`0x80` of the first header byte).
    fin: bool,
    /// RSV1 bit (`0x40` of the first header byte).
    rsv1: bool,
    /// RSV2 bit (`0x20` of the first header byte).
    rsv2: bool,
    /// RSV3 bit (`0x10` of the first header byte).
    rsv3: bool,
    /// Opcode (low four bits of the first header byte).
    opcode: OpCode,
    /// Masking key; `Some` when the MASK bit (`0x80` of the second header byte) was set.
    mask: Option<[u8; 4]>,
    /// Payload length in the encoding used by the header.
    size: Size,
    /// Payload bytes, stored as received.
    data: Vec<u8>,
}
189
190impl Frame {
191 pub fn mask(mut self) -> Frame {
193 if let Some(mask) = &self.mask {
194 for (i, v) in self.data.iter_mut().enumerate() {
195 *v ^= mask[i % 4];
196 }
197 }
198 self
199 }
200
201 pub fn into_vec(self) -> Vec<u8> {
203 let b1 = ((self.fin as u8) << 7)
204 | ((self.rsv1 as u8) << 6)
205 | ((self.rsv2 as u8) << 5)
206 | ((self.rsv3 as u8) << 4)
207 | (u8::from(self.opcode) & 0x0F);
208 let b2 = ((self.mask.is_some() as u8) << 7) | (self.size.first_byte() & 0x7F);
209 let mut r = vec![b1, b2];
210 r.extend(self.size.final_size());
211 if let Some(ref mask) = self.mask {
212 r.extend(mask.iter());
213 }
214 r.extend(&self.data);
215 r
216 }
217}
218
219impl From<Frame> for Vec<u8> {
220 fn from(frame: Frame) -> Self { frame.into_vec() }
221}
222
/// Tag recording whether a frame's payload is treated as masked (still XOR-ed with its key)
/// or as plain bytes.
#[derive(Debug, Clone)]
enum FrameState {
    /// Payload is in masked form.
    Masked(Frame),
    /// Payload is in plain form.
    UnMasked(Frame),
}
228
229impl From<FrameState> for Frame {
230 fn from(x: FrameState) -> Self {
231 match x {
232 FrameState::Masked(x) => x,
233 FrameState::UnMasked(x) => x,
234 }
235 }
236}
237
238impl Deref for FrameState {
239 type Target = Frame;
240
241 fn deref(&self) -> &Self::Target {
242 match self {
243 FrameState::Masked(x) | FrameState::UnMasked(x) => x,
244 }
245 }
246}
247
248impl FrameState {
249 fn unmask(self) -> FrameState {
250 match self {
251 FrameState::Masked(frame) => FrameState::UnMasked(frame.mask()),
252 x @ FrameState::UnMasked(_) => x,
253 }
254 }
255
256 fn mask(self) -> FrameState {
257 match self {
258 FrameState::UnMasked(frame) => FrameState::Masked(frame.mask()),
259 x @ FrameState::Masked(_) => x,
260 }
261 }
262
263 pub fn into_frame(self) -> Frame {
264 match self {
265 FrameState::Masked(x) | FrameState::UnMasked(x) => x,
266 }
267 }
268}
269
/// A [`Frame`] together with whether its payload is currently in masked form.
///
/// [`masked_frame`] produces it tagged masked, and [`unmasked_frame`] tagged plain.
#[derive(Debug, Clone)]
pub struct MaskedFrame(FrameState);
273
274impl MaskedFrame {
275 pub fn mask(self) -> MaskedFrame { MaskedFrame(self.0.mask()) }
277
278 pub fn unmask(self) -> MaskedFrame { MaskedFrame(self.0.unmask()) }
280
281 pub fn into_frame(self) -> Frame { self.0.into_frame() }
283}
284
285impl Deref for MaskedFrame {
286 type Target = Frame;
287
288 fn deref(&self) -> &Self::Target { &self.0 }
289}
290
291impl From<MaskedFrame> for Frame {
292 fn from(x: MaskedFrame) -> Self { x.0.into() }
293}
294
/// Error for a 7-bit length indicator that is not a recognised size encoding.
/// Carries the offending indicator value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct InvalidFrameSize(u8);

impl std::fmt::Display for InvalidFrameSize {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let InvalidFrameSize(value) = self;
        write!(f, "InvalidFrameSize: {value}")
    }
}

// Leaf error: no underlying cause, so the default `source` (None) is correct.
impl std::error::Error for InvalidFrameSize {}
306
/// Error type of the frame parsers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FrameSizeError {
    /// Wraps the crate's `Incomplete` error (presumably: the input ended too early — confirm).
    Incomplete(Incomplete),
    /// The 7-bit length indicator was not a recognised encoding.
    InvalidSize(InvalidFrameSize),
}
315
316impl From<Incomplete> for FrameSizeError {
317 fn from(value: Incomplete) -> Self { FrameSizeError::Incomplete(value) }
318}
319
320impl From<InvalidFrameSize> for FrameSizeError {
321 fn from(value: InvalidFrameSize) -> Self { FrameSizeError::InvalidSize(value) }
322}
323
324impl std::fmt::Display for FrameSizeError {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 match self {
327 FrameSizeError::Incomplete(x) => write!(f, "FrameSizeError: {x}"),
328 FrameSizeError::InvalidSize(x) => write!(f, "FrameSizeError: {x}"),
329 }
330 }
331}
332
333impl std::error::Error for FrameSizeError {
334 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
335 match self {
336 FrameSizeError::Incomplete(x) => Some(x),
337 FrameSizeError::InvalidSize(x) => Some(x),
338 }
339 }
340}
341
342fn parse_frame_size(buf: &[u8], head: u16) -> PResult<&[u8], (Size,), FrameSizeError> {
343 match (head as u8) & 0x7F {
344 x @ 0..=125 => Ok((buf, (Size::U8(x),))),
345 126 => be_u16().map1(Size::U16).apply(buf),
346 127 => be_u64().map1(Size::U64).apply(buf),
347 x => Err(FatalError::Error(InvalidFrameSize(x).into())),
348 }
349}
350
/// Parser for one WebSocket frame at the front of the input.
///
/// The parser reads, in order:
/// - the two-byte header;
/// - the (possibly extended) payload length;
/// - the 4-byte masking key, when the header's MASK bit is set;
/// - exactly that many payload bytes.
///
/// The payload is copied out as-is; no unmasking happens here.
///
/// # Errors
///
/// Propagates errors from the byte readers (presumably `Incomplete` when the input is too
/// short — confirm against the crate's `take`/`be_u16`) and from [`parse_frame_size`].
pub fn frame<'a>() -> impl Pipe<&'a [u8], (Frame,), FrameSizeError> {
    move |x: &'a [u8]| {
        // High byte: FIN, RSV1-3 and opcode. Low byte: MASK bit and 7-bit length.
        let (buf, (head,)) = be_u16().apply(x)?;
        // Decode the length, then read the masking key only if the MASK bit (0x80 of the
        // low byte) is set.
        let (buf, (size, mask)) = { move |x| parse_frame_size(x, head) }
            .ok_and_then(|i, (o,)| {
                if head & 0x80 == 0x80 {
                    Ok(const_take::<4, _>().map(|x: [u8; 4]| (o, Some(x))).apply(i)?)
                } else {
                    Ok((i, (o, None)))
                }
            })
            .apply(buf)?;
        // Take exactly the declared number of payload bytes.
        let (buf, (data,)) = take(size.into()).apply(buf)?;
        Ok((
            buf,
            (Frame {
                // Flag bits and opcode all live in the high byte of `head`.
                fin: (head >> 8) & 0x80 == 0x80,
                rsv1: (head >> 8) & 0x40 == 0x40,
                rsv2: (head >> 8) & 0x20 == 0x20,
                rsv3: (head >> 8) & 0x10 == 0x10,
                opcode: (((head >> 8) as u8) & 0x0F).into(),
                size,
                mask,
                data: data.to_vec(),
            },),
        ))
    }
}
380
381pub fn masked_frame<'a>() -> impl Pipe<&'a [u8], (MaskedFrame,), FrameSizeError> {
383 frame().map1(|x: Frame| MaskedFrame(FrameState::Masked(x)))
384}
385
386pub fn unmasked_frame<'a>() -> impl Pipe<&'a [u8], (MaskedFrame,), FrameSizeError> {
388 frame().map1(|x| MaskedFrame(FrameState::UnMasked(x)))
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Pipe, UnpackExt};

    /// RFC 6455 §5.7: single-frame unmasked text message "Hello".
    static UNMASKED_HELLO: [u8; 7] = [0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f];
    /// RFC 6455 §5.7: single-frame masked text message "Hello".
    static MASKED_HELLO: [u8; 11] =
        [0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58];

    #[test]
    fn unmasked_text_frame() {
        let (rest, (f,)) = unmasked_frame().apply(&UNMASKED_HELLO).unwrap();
        assert!(rest.is_empty());
        assert_eq!(f.data, b"Hello");
        assert!(f.fin);
        assert!(!f.rsv1);
        assert!(!f.rsv2);
        assert!(!f.rsv3);
        assert_eq!(f.mask, None);
        assert_eq!(f.opcode, OpCode::Text);
        assert_eq!(f.size, Size::U8(5));
        let v: Vec<u8> = f.into_frame().into();
        assert_eq!(&v, &UNMASKED_HELLO);
    }

    #[test]
    fn masked_text_frame() {
        let (rest, (f,)) = masked_frame().apply(&MASKED_HELLO).unwrap();
        assert!(rest.is_empty());
        assert_eq!(f.mask, Some([0x37, 0xfa, 0x21, 0x3d]));
        // Serializing the still-masked frame reproduces the wire bytes.
        let v: Vec<u8> = f.clone().into_frame().into();
        assert_eq!(&v, &MASKED_HELLO);
        let f = f.unmask();
        assert_eq!(f.data, b"Hello");
        // Masking again restores the original payload, hence the original bytes.
        let v: Vec<u8> = f.mask().into_frame().into();
        assert_eq!(&v, &MASKED_HELLO);
    }

    #[test]
    fn fragmented_text_message() {
        let (rest, (first,)) = unmasked_frame().apply(&[0x01, 0x03, 0x48, 0x65, 0x6c]).unwrap();
        assert!(rest.is_empty());
        assert_eq!(first.data, b"Hel");
        assert!(!first.fin);
        assert_eq!(first.opcode, OpCode::Text);

        let (rest, (last,)) = unmasked_frame().apply(&[0x80, 0x02, 0x6c, 0x6f]).unwrap();
        assert!(rest.is_empty());
        assert_eq!(last.data, b"lo");
        assert!(last.fin);
        assert_eq!(last.opcode, OpCode::Continuation);
    }

    #[test]
    fn unmasked_ping_frame() {
        let (rest, (f,)) =
            unmasked_frame().apply(&[0x89, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f]).unwrap();
        assert!(rest.is_empty());
        assert_eq!(f.data, b"Hello");
        assert!(f.fin);
        assert_eq!(f.opcode, OpCode::Ping);
    }

    #[test]
    fn masked_pong_frame() {
        let r = masked_frame()
            .map1(MaskedFrame::unmask)
            .unpack()
            .apply(&[0x8a, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58])
            .unwrap();
        assert_eq!(r.0, b"");
        assert_eq!(r.1 .0.data, b"Hello");
        assert!(r.1 .0.fin);
        assert_eq!(r.1 .0.opcode, OpCode::Pong);
    }

    #[test]
    fn extended_16_bit_length() {
        let mut buf = [0u8; 260];
        buf[0] = 0x82;
        buf[1] = 0x7E;
        buf[2] = 0x01;
        let (r, (f,)) = unmasked_frame().apply(&buf).unwrap();
        assert!(r.is_empty());
        assert_eq!(f.size, Size::U16(256));
        let v: Vec<u8> = f.into_frame().into();
        assert_eq!(&v, &buf);
    }

    #[test]
    fn extended_64_bit_length() {
        let mut buf = [0u8; 65546];
        buf[0] = 0x82;
        buf[1] = 0x7F;
        buf[7] = 0x01;
        let (r, (f,)) = unmasked_frame().apply(&buf).unwrap();
        assert!(r.is_empty());
        assert_eq!(f.size, Size::U64(65536));
        let v: Vec<u8> = f.into_frame().into();
        assert_eq!(&v, &buf);
    }

    #[test]
    fn truncated_input_is_an_error() {
        // Cases: half a header, and a header promising more payload than is present.
        assert!(unmasked_frame().apply(&[0x81]).is_err());
        assert!(unmasked_frame().apply(&[0x81, 0x05, 0x48]).is_err());
    }
}