1use crate::frame::{
2 CloseBodyError, FrameHeadDecodeState, FrameHeadParseError, FramePayloadReaderState,
3 WsControlFrame, WsControlFrameKind, WsControlFramePayload, WsDataFrame, WsFrame, WsFrameKind,
4};
5use futures::prelude::*;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9#[derive(Debug)]
10pub struct FrameDecoder<T: AsyncRead + Unpin> {
11 transport: Option<T>,
12 state: FrameDecoderState,
13}
14
15impl<T: AsyncRead + Unpin> FrameDecoder<T> {
16 pub fn checkpoint(self) -> Option<(T, FrameDecoderState)> {
17 let (transport, state) = (self.transport, self.state);
18 transport.map(|transport| (transport, state))
19 }
20}
21
/// Progress of a frame decode, independent of the transport being read.
///
/// The state can be paired with a transport via `restore`, or driven directly via `poll`.
#[derive(Debug)]
pub enum FrameDecoderState {
    /// Reading the frame head.
    Head(FrameHeadDecodeState),
    /// The head described a control frame; its payload is being read into `frame`.
    ControlPayload {
        /// Control frame being assembled. `payload.len` counts the payload bytes
        /// copied into `payload.buffer` so far.
        frame: WsControlFrame,
        /// Payload reader built from the head's mask and payload length.
        reader: FramePayloadReaderState,
    },
}
30
31impl FrameDecoderState {
32 pub fn new() -> Self {
33 Self::Head(FrameHeadDecodeState::new())
34 }
35 pub fn restore<T: AsyncRead + Unpin>(self, transport: T) -> FrameDecoder<T> {
36 FrameDecoder {
37 transport: Some(transport),
38 state: self,
39 }
40 }
41 pub fn poll<T: AsyncRead + Unpin>(
42 &mut self,
43 transport: &mut T,
44 cx: &mut Context<'_>,
45 ) -> Poll<Result<WsFrame, FrameDecodeError>> {
46 loop {
47 match self {
48 FrameDecoderState::Head(state) => match state.poll(transport, cx) {
49 Poll::Ready(Ok(frame_head)) => match frame_head.opcode.frame_kind() {
50 WsFrameKind::Data(frame_kind) => {
51 return Poll::Ready(Ok(WsFrame::Data(WsDataFrame {
52 kind: frame_kind,
53 fin: frame_head.fin,
54 mask: frame_head.mask,
55 payload_len: frame_head.payload_len,
56 })))
57 }
58 WsFrameKind::Control(frame_kind) => {
59 *self = FrameDecoderState::ControlPayload {
60 frame: WsControlFrame {
61 kind: frame_kind,
62 payload: WsControlFramePayload {
63 len: 0,
64 buffer: [0u8; 125],
65 },
66 },
67 reader: FramePayloadReaderState::new(
68 frame_head.mask,
69 frame_head.payload_len,
70 ),
71 };
72 }
73 },
74 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
75 Poll::Pending => return Poll::Pending,
76 },
77 FrameDecoderState::ControlPayload { frame, reader } => {
78 let off = frame.payload.len as usize;
79 match reader.poll_read(transport, cx, &mut frame.payload.buffer[off..]) {
80 Poll::Ready(Ok(n)) => {
81 if n == 0 {
82 if frame.kind == WsControlFrameKind::Close {
83 if let Err(err) = frame.payload.close_body() {
84 return Poll::Ready(Err(err.into()));
85 }
86 }
87 return Poll::Ready(Ok(WsFrame::Control(*frame)));
88 }
89 frame.payload.len += n as u8
90 }
91 Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
92 Poll::Pending => return Poll::Pending,
93 }
94 }
95 }
96 }
97 }
98}
99
100impl<T: AsyncRead + Unpin> Future for FrameDecoder<T> {
101 type Output = Result<(T, WsFrame), FrameDecodeError>;
102
103 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104 let mut transport = self.transport.take().unwrap();
105 match self.state.poll(&mut transport, cx) {
106 Poll::Ready(Ok(frame)) => Poll::Ready(Ok((transport, frame))),
107 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
108 Poll::Pending => {
109 self.transport = Some(transport);
110 Poll::Pending
111 }
112 }
113 }
114}
115
/// Errors produced while decoding a frame.
#[derive(thiserror::Error, Debug)]
pub enum FrameDecodeError {
    /// Wraps a `std::io::Error` (presumably from reading the transport — confirm).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Wraps a [`FrameHeadParseError`]; presumably the frame head was malformed.
    #[error("parse error: {0}")]
    ParseErr(#[from] FrameHeadParseError),
    /// A `Close` frame's body failed validation (wraps [`CloseBodyError`]).
    #[error("invalid close body: {0}")]
    InvalidCloseBody(#[from] CloseBodyError),
}