1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
use crate::frame::payload_mask;
use futures::prelude::*;
use std::convert::TryFrom;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
#[derive(Debug)]
pub struct FramePayloadReader<T: AsyncRead + Unpin> {
transport: T,
state: FramePayloadReaderState,
}
impl<T: AsyncRead + Unpin> FramePayloadReader<T> {
pub fn into_inner(self) -> T {
self.transport
}
pub fn checkpoint(self) -> (T, FramePayloadReaderState) {
(self.transport, self.state)
}
}
#[derive(Debug)]
pub struct FramePayloadReaderState {
mask: [u8; 4],
payload_len: u64,
completion: u64,
}
impl FramePayloadReaderState {
pub fn new(mask: [u8; 4], payload_len: u64) -> Self {
Self {
mask,
payload_len,
completion: 0,
}
}
pub fn restore<T: AsyncRead + Unpin>(self, transport: T) -> FramePayloadReader<T> {
FramePayloadReader {
transport,
state: self,
}
}
pub fn poll_read<T: AsyncRead + Unpin>(
&mut self,
transport: &mut T,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if self.payload_len <= self.completion || buf.len() == 0 {
return Poll::Ready(Ok(0));
}
let min = match usize::try_from(self.payload_len - self.completion) {
Ok(remainder) => remainder.min(buf.len()),
Err(_) => buf.len(),
};
match Pin::new(transport).poll_read(cx, &mut buf[0..min]) {
Poll::Ready(Ok(n)) => match n {
0 => Poll::Ready(Err(io::Error::from(io::ErrorKind::UnexpectedEof))),
n => {
payload_mask(self.mask, self.completion as usize, buf);
self.completion += n as u64;
Poll::Ready(Ok(n))
}
},
p => p,
}
}
pub fn finished(&self) -> bool {
self.payload_len == self.completion
}
}
impl<T: AsyncRead + Unpin> AsyncRead for FramePayloadReader<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let Self { transport, state } = self.get_mut();
state.poll_read(transport, cx, buf)
}
}