1use crate::connection::Id;
12use futures::{prelude::*, ready};
13use std::{fmt, io, pin::Pin, task::{Context, Poll}};
14use super::{Frame, header::{self, HeaderDecodeError}};
15
/// Frame-level I/O over an underlying byte stream `io`.
///
/// Frames are written with `send` and read by polling this type as a `Stream`.
#[derive(Debug)]
pub(crate) struct Io<T> {
    /// Connection identifier; in this file it is only used to prefix trace logs.
    id: Id,
    /// The underlying resource that frames are read from and written to.
    io: T,
    /// Progress of the frame currently being read.
    state: ReadState,
    /// Largest body length accepted for incoming `Data` frames.
    max_body_len: usize
}
24
25impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
26 pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
27 Io {
28 id,
29 io,
30 state: ReadState::Init,
31 max_body_len: max_frame_body_len
32 }
33 }
34
35 pub(crate) async fn send<A>(&mut self, frame: &Frame<A>) -> io::Result<()> {
36 let header = header::encode(&frame.header);
37 self.io.write_all(&header).await?;
38 self.io.write_all(&frame.body).await
39 }
40
41 pub(crate) async fn flush(&mut self) -> io::Result<()> {
42 self.io.flush().await
43 }
44
45 pub(crate) async fn close(&mut self) -> io::Result<()> {
46 self.io.close().await
47 }
48}
49
/// State of the incremental frame reader driven by `Io`'s `Stream::poll_next`.
enum ReadState {
    /// Nothing read yet; a new frame is about to start.
    Init,
    /// Reading the fixed-size frame header.
    Header {
        /// Number of header bytes read so far.
        offset: usize,
        /// Buffer receiving the header bytes.
        buffer: [u8; header::HEADER_SIZE]
    },
    /// Reading the body of a `Data` frame whose header has already been decoded.
    Body {
        /// The decoded header of the frame being read.
        header: header::Header<()>,
        /// Number of body bytes read so far.
        offset: usize,
        /// Body buffer, allocated with the length announced by the header.
        buffer: Vec<u8>
    }
}
66
67impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
68 type Item = Result<Frame<()>, FrameDecodeError>;
69
70 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
71 let mut this = &mut *self;
72 loop {
73 log::trace!("{}: read: {:?}", this.id, this.state);
74 match this.state {
75 ReadState::Init => {
76 this.state = ReadState::Header {
77 offset: 0,
78 buffer: [0; header::HEADER_SIZE]
79 };
80 }
81 ReadState::Header { ref mut offset, ref mut buffer } => {
82 if *offset == header::HEADER_SIZE {
83 let header =
84 match header::decode(&buffer) {
85 Ok(hd) => hd,
86 Err(e) => return Poll::Ready(Some(Err(e.into())))
87 };
88
89 log::trace!("{}: read: {}", this.id, header);
90
91 if header.tag() != header::Tag::Data {
92 this.state = ReadState::Init;
93 return Poll::Ready(Some(Ok(Frame::new(header))))
94 }
95
96 let body_len = header.len().val() as usize;
97
98 if body_len > this.max_body_len {
99 return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(body_len))))
100 }
101
102 this.state = ReadState::Body {
103 header,
104 offset: 0,
105 buffer: vec![0; body_len]
106 };
107
108 continue
109 }
110
111 let buf = &mut buffer[*offset .. header::HEADER_SIZE];
112 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
113 0 => {
114 if *offset == 0 {
115 return Poll::Ready(None)
116 }
117 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
118 return Poll::Ready(Some(Err(e)))
119 }
120 n => *offset += n
121 }
122 }
123 ReadState::Body { ref header, ref mut offset, ref mut buffer } => {
124 let body_len = header.len().val() as usize;
125
126 if *offset == body_len {
127 let h = header.clone();
128 let v = std::mem::take(buffer);
129 this.state = ReadState::Init;
130 return Poll::Ready(Some(Ok(Frame { header: h, body: v })))
131 }
132
133 let buf = &mut buffer[*offset .. body_len];
134 match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
135 0 => {
136 let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
137 return Poll::Ready(Some(Err(e)))
138 }
139 n => *offset += n
140 }
141 }
142 }
143 }
144 }
145}
146
147impl fmt::Debug for ReadState {
148 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149 match self {
150 ReadState::Init => {
151 f.write_str("(ReadState::Init)")
152 }
153 ReadState::Header { offset, .. } => {
154 write!(f, "(ReadState::Header {})", offset)
155 }
156 ReadState::Body { header, offset, buffer } => {
157 write!(f, "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
158 header,
159 offset,
160 buffer.len())
161 }
162 }
163 }
164}
165
/// Possible errors while decoding a frame from the underlying I/O.
#[non_exhaustive]
#[derive(Debug)]
pub enum FrameDecodeError {
    /// An I/O error from the underlying resource. This includes an
    /// `UnexpectedEof` when the stream ends in the middle of a frame.
    Io(io::Error),
    /// The frame header could not be decoded.
    Header(HeaderDecodeError),
    /// A data frame announced a body longer than the configured maximum.
    /// Carries the announced body length.
    FrameTooLarge(usize)
}
177
178impl std::fmt::Display for FrameDecodeError {
179 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
180 match self {
181 FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
182 FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
183 FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n)
184 }
185 }
186}
187
188impl std::error::Error for FrameDecodeError {
189 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
190 match self {
191 FrameDecodeError::Io(e) => Some(e),
192 FrameDecodeError::Header(e) => Some(e),
193 FrameDecodeError::FrameTooLarge(_) => None
194 }
195 }
196}
197
198impl From<std::io::Error> for FrameDecodeError {
199 fn from(e: std::io::Error) -> Self {
200 FrameDecodeError::Io(e)
201 }
202}
203
204impl From<HeaderDecodeError> for FrameDecodeError {
205 fn from(e: HeaderDecodeError) -> Self {
206 FrameDecodeError::Header(e)
207 }
208}
209
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;
    use super::*;

    /// Random frames for property testing.
    ///
    /// Data frames get a random body of at most 4095 bytes whose length matches
    /// the header. All other frames have an empty body.
    impl Arbitrary for Frame<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            if header.tag() != header::Tag::Data {
                return Frame { header, body: Vec::new() }
            }
            header.set_len(header.len().val() % 4096);
            let mut body = vec![0; header.len().val() as usize];
            rand::thread_rng().fill_bytes(&mut body);
            Frame { header, body }
        }
    }

    /// Sending a frame and reading it back must yield an identical frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let cursor = futures::io::Cursor::new(Vec::new());
                let mut io = Io::new(id, cursor, f.body.len());
                // Short-circuits: flush is only attempted after a successful send.
                let written = io.send(&f).await.is_ok() && io.flush().await.is_ok();
                if !written {
                    return false
                }
                // Rewind so the reader sees the bytes that were just written.
                io.io.set_position(0);
                match io.try_next().await {
                    Ok(Some(decoded)) => decoded == f,
                    _ => false
                }
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }
}
258