async_ws/frame/frame_payload/decode.rs

1use crate::frame::payload_mask;
2use futures::prelude::*;
3use std::convert::TryFrom;
4use std::io;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
8#[derive(Debug)]
9pub struct FramePayloadReader<T: AsyncRead + Unpin> {
10    transport: T,
11    state: FramePayloadReaderState,
12}
13
14impl<T: AsyncRead + Unpin> FramePayloadReader<T> {
15    pub fn into_inner(self) -> T {
16        self.transport
17    }
18    pub fn checkpoint(self) -> (T, FramePayloadReaderState) {
19        (self.transport, self.state)
20    }
21}
22
/// Resumable progress of reading one frame payload.
///
/// Holds the 4-byte masking key, the total payload length and the number of
/// payload bytes delivered so far. It can be detached from a transport
/// ([`FramePayloadReader::checkpoint`]) and reattached later
/// ([`FramePayloadReaderState::restore`]). It is plain data, so it can be
/// cloned or compared, e.g. to duplicate or inspect a checkpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FramePayloadReaderState {
    /// Masking key passed to `payload_mask` for every chunk read.
    mask: [u8; 4],
    /// Total payload length in bytes.
    payload_len: u64,
    /// Payload bytes already read; reads are capped so this never exceeds
    /// `payload_len`.
    completion: u64,
}
29
30impl FramePayloadReaderState {
31    pub fn new(mask: [u8; 4], payload_len: u64) -> Self {
32        Self {
33            mask,
34            payload_len,
35            completion: 0,
36        }
37    }
38    pub fn restore<T: AsyncRead + Unpin>(self, transport: T) -> FramePayloadReader<T> {
39        FramePayloadReader {
40            transport,
41            state: self,
42        }
43    }
44    pub fn poll_read<T: AsyncRead + Unpin>(
45        &mut self,
46        transport: &mut T,
47        cx: &mut Context<'_>,
48        buf: &mut [u8],
49    ) -> Poll<io::Result<usize>> {
50        if self.payload_len <= self.completion || buf.len() == 0 {
51            return Poll::Ready(Ok(0));
52        }
53        let min = match usize::try_from(self.payload_len - self.completion) {
54            Ok(remainder) => remainder.min(buf.len()),
55            Err(_) => buf.len(),
56        };
57        match Pin::new(transport).poll_read(cx, &mut buf[0..min]) {
58            Poll::Ready(Ok(n)) => match n {
59                0 => Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof))),
60                n => {
61                    payload_mask(self.mask, self.completion as usize, buf);
62                    self.completion += n as u64;
63                    Poll::Ready(Ok(n))
64                }
65            },
66            p => p,
67        }
68    }
69    pub fn finished(&self) -> bool {
70        self.payload_len == self.completion
71    }
72}
73
74impl<T: AsyncRead + Unpin> AsyncRead for FramePayloadReader<T> {
75    fn poll_read(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        buf: &mut [u8],
79    ) -> Poll<io::Result<usize>> {
80        let Self { transport, state } = self.get_mut();
81        state.poll_read(transport, cx, buf)
82    }
83}