// async_ws/frame/decode.rs

1use crate::frame::{
2    CloseBodyError, FrameHeadDecodeState, FrameHeadParseError, FramePayloadReaderState,
3    WsControlFrame, WsControlFrameKind, WsControlFramePayload, WsDataFrame, WsFrame, WsFrameKind,
4};
5use futures::prelude::*;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9#[derive(Debug)]
10pub struct FrameDecoder<T: AsyncRead + Unpin> {
11    transport: Option<T>,
12    state: FrameDecoderState,
13}
14
15impl<T: AsyncRead + Unpin> FrameDecoder<T> {
16    pub fn checkpoint(self) -> Option<(T, FrameDecoderState)> {
17        let (transport, state) = (self.transport, self.state);
18        transport.map(|transport| (transport, state))
19    }
20}
21
22#[derive(Debug)]
23pub enum FrameDecoderState {
24    Head(FrameHeadDecodeState),
25    ControlPayload {
26        frame: WsControlFrame,
27        reader: FramePayloadReaderState,
28    },
29}
30
31impl FrameDecoderState {
32    pub fn new() -> Self {
33        Self::Head(FrameHeadDecodeState::new())
34    }
35    pub fn restore<T: AsyncRead + Unpin>(self, transport: T) -> FrameDecoder<T> {
36        FrameDecoder {
37            transport: Some(transport),
38            state: self,
39        }
40    }
41    pub fn poll<T: AsyncRead + Unpin>(
42        &mut self,
43        transport: &mut T,
44        cx: &mut Context<'_>,
45    ) -> Poll<Result<WsFrame, FrameDecodeError>> {
46        loop {
47            match self {
48                FrameDecoderState::Head(state) => match state.poll(transport, cx) {
49                    Poll::Ready(Ok(frame_head)) => match frame_head.opcode.frame_kind() {
50                        WsFrameKind::Data(frame_kind) => {
51                            return Poll::Ready(Ok(WsFrame::Data(WsDataFrame {
52                                kind: frame_kind,
53                                fin: frame_head.fin,
54                                mask: frame_head.mask,
55                                payload_len: frame_head.payload_len,
56                            })))
57                        }
58                        WsFrameKind::Control(frame_kind) => {
59                            *self = FrameDecoderState::ControlPayload {
60                                frame: WsControlFrame {
61                                    kind: frame_kind,
62                                    payload: WsControlFramePayload {
63                                        len: 0,
64                                        buffer: [0u8; 125],
65                                    },
66                                },
67                                reader: FramePayloadReaderState::new(
68                                    frame_head.mask,
69                                    frame_head.payload_len,
70                                ),
71                            };
72                        }
73                    },
74                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
75                    Poll::Pending => return Poll::Pending,
76                },
77                FrameDecoderState::ControlPayload { frame, reader } => {
78                    let off = frame.payload.len as usize;
79                    match reader.poll_read(transport, cx, &mut frame.payload.buffer[off..]) {
80                        Poll::Ready(Ok(n)) => {
81                            if n == 0 {
82                                if frame.kind == WsControlFrameKind::Close {
83                                    if let Err(err) = frame.payload.close_body() {
84                                        return Poll::Ready(Err(err.into()));
85                                    }
86                                }
87                                return Poll::Ready(Ok(WsFrame::Control(*frame)));
88                            }
89                            frame.payload.len += n as u8
90                        }
91                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err.into())),
92                        Poll::Pending => return Poll::Pending,
93                    }
94                }
95            }
96        }
97    }
98}
99
100impl<T: AsyncRead + Unpin> Future for FrameDecoder<T> {
101    type Output = Result<(T, WsFrame), FrameDecodeError>;
102
103    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
104        let mut transport = self.transport.take().unwrap();
105        match self.state.poll(&mut transport, cx) {
106            Poll::Ready(Ok(frame)) => Poll::Ready(Ok((transport, frame))),
107            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
108            Poll::Pending => {
109                self.transport = Some(transport);
110                Poll::Pending
111            }
112        }
113    }
114}
115
116#[derive(thiserror::Error, Debug)]
117pub enum FrameDecodeError {
118    #[error("io error: {0}")]
119    Io(#[from] std::io::Error),
120    #[error("parse error: {0}")]
121    ParseErr(#[from] FrameHeadParseError),
122    #[error("invalid close body: {0}")]
123    InvalidCloseBody(#[from] CloseBodyError),
124}