remux/frame/
io.rs

// Copyright (c) 2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.
10
11use crate::connection::Id;
12use futures::{prelude::*, ready};
13use std::{fmt, io, pin::Pin, task::{Context, Poll}};
14use super::{Frame, header::{self, HeaderDecodeError}};
15
16/// A [`Stream`] and writer of [`Frame`] values.
17#[derive(Debug)]
18pub(crate) struct Io<T> {
19    id: Id,
20    io: T,
21    state: ReadState,
22    max_body_len: usize
23}
24
25impl<T: AsyncRead + AsyncWrite + Unpin> Io<T> {
26    pub(crate) fn new(id: Id, io: T, max_frame_body_len: usize) -> Self {
27        Io {
28            id,
29            io,
30            state: ReadState::Init,
31            max_body_len: max_frame_body_len
32        }
33    }
34
35    pub(crate) async fn send<A>(&mut self, frame: &Frame<A>) -> io::Result<()> {
36        let header = header::encode(&frame.header);
37        self.io.write_all(&header).await?;
38        self.io.write_all(&frame.body).await
39    }
40
41    pub(crate) async fn flush(&mut self) -> io::Result<()> {
42        self.io.flush().await
43    }
44
45    pub(crate) async fn close(&mut self) -> io::Result<()> {
46        self.io.close().await
47    }
48}
49
50/// The stages of reading a new `Frame`.
51enum ReadState {
52    /// Initial reading state.
53    Init,
54    /// Reading the frame header.
55    Header {
56        offset: usize,
57        buffer: [u8; header::HEADER_SIZE]
58    },
59    /// Reading the frame body.
60    Body {
61        header: header::Header<()>,
62        offset: usize,
63        buffer: Vec<u8>
64    }
65}
66
67impl<T: AsyncRead + AsyncWrite + Unpin> Stream for Io<T> {
68    type Item = Result<Frame<()>, FrameDecodeError>;
69
70    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
71        let mut this = &mut *self;
72        loop {
73            log::trace!("{}: read: {:?}", this.id, this.state);
74            match this.state {
75                ReadState::Init => {
76                    this.state = ReadState::Header {
77                        offset: 0,
78                        buffer: [0; header::HEADER_SIZE]
79                    };
80                }
81                ReadState::Header { ref mut offset, ref mut buffer } => {
82                    if *offset == header::HEADER_SIZE {
83                        let header =
84                            match header::decode(&buffer) {
85                                Ok(hd) => hd,
86                                Err(e) => return Poll::Ready(Some(Err(e.into())))
87                            };
88
89                        log::trace!("{}: read: {}", this.id, header);
90
91                        if header.tag() != header::Tag::Data {
92                            this.state = ReadState::Init;
93                            return Poll::Ready(Some(Ok(Frame::new(header))))
94                        }
95
96                        let body_len = header.len().val() as usize;
97
98                        if body_len > this.max_body_len {
99                            return Poll::Ready(Some(Err(FrameDecodeError::FrameTooLarge(body_len))))
100                        }
101
102                        this.state = ReadState::Body {
103                            header,
104                            offset: 0,
105                            buffer: vec![0; body_len]
106                        };
107
108                        continue
109                    }
110
111                    let buf = &mut buffer[*offset .. header::HEADER_SIZE];
112                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
113                        0 => {
114                            if *offset == 0 {
115                                return Poll::Ready(None)
116                            }
117                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
118                            return Poll::Ready(Some(Err(e)))
119                        }
120                        n => *offset += n
121                    }
122                }
123                ReadState::Body { ref header, ref mut offset, ref mut buffer } => {
124                    let body_len = header.len().val() as usize;
125
126                    if *offset == body_len {
127                        let h = header.clone();
128                        let v = std::mem::take(buffer);
129                        this.state = ReadState::Init;
130                        return Poll::Ready(Some(Ok(Frame { header: h, body: v })))
131                    }
132
133                    let buf = &mut buffer[*offset .. body_len];
134                    match ready!(Pin::new(&mut this.io).poll_read(cx, buf))? {
135                        0 => {
136                            let e = FrameDecodeError::Io(io::ErrorKind::UnexpectedEof.into());
137                            return Poll::Ready(Some(Err(e)))
138                        }
139                        n => *offset += n
140                    }
141                }
142            }
143        }
144    }
145}
146
147impl fmt::Debug for ReadState {
148    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
149        match self {
150            ReadState::Init => {
151                f.write_str("(ReadState::Init)")
152            }
153            ReadState::Header { offset, .. } => {
154                write!(f, "(ReadState::Header {})", offset)
155            }
156            ReadState::Body { header, offset, buffer } => {
157                write!(f, "(ReadState::Body (header {}) (offset {}) (buffer-len {}))",
158                    header,
159                    offset,
160                    buffer.len())
161            }
162        }
163    }
164}
165
166/// Possible errors while decoding a message frame.
167#[non_exhaustive]
168#[derive(Debug)]
169pub enum FrameDecodeError {
170    /// An I/O error.
171    Io(io::Error),
172    /// Decoding the frame header failed.
173    Header(HeaderDecodeError),
174    /// A data frame body length is larger than the configured maximum.
175    FrameTooLarge(usize)
176}
177
178impl std::fmt::Display for FrameDecodeError {
179    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
180        match self {
181            FrameDecodeError::Io(e) => write!(f, "i/o error: {}", e),
182            FrameDecodeError::Header(e) => write!(f, "decode error: {}", e),
183            FrameDecodeError::FrameTooLarge(n) => write!(f, "frame body is too large ({})", n)
184        }
185    }
186}
187
188impl std::error::Error for FrameDecodeError {
189    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
190        match self {
191            FrameDecodeError::Io(e) => Some(e),
192            FrameDecodeError::Header(e) => Some(e),
193            FrameDecodeError::FrameTooLarge(_) => None
194        }
195    }
196}
197
198impl From<std::io::Error> for FrameDecodeError {
199    fn from(e: std::io::Error) -> Self {
200        FrameDecodeError::Io(e)
201    }
202}
203
204impl From<HeaderDecodeError> for FrameDecodeError {
205    fn from(e: HeaderDecodeError) -> Self {
206        FrameDecodeError::Header(e)
207    }
208}
209
#[cfg(test)]
mod tests {
    use quickcheck::{Arbitrary, Gen, QuickCheck};
    use rand::RngCore;
    use super::*;

    impl Arbitrary for Frame<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let mut header: header::Header<()> = Arbitrary::arbitrary(g);
            let body =
                if header.tag() == header::Tag::Data {
                    // Keep bodies small so many iterations stay fast.
                    header.set_len(header.len().val() % 4096);
                    let mut b = vec![0; header.len().val() as usize];
                    rand::thread_rng().fill_bytes(&mut b);
                    b
                } else {
                    Vec::new()
                };
            Frame { header, body }
        }
    }

    /// A frame written with `send` decodes back to an equal frame.
    #[test]
    fn encode_decode_identity() {
        fn property(f: Frame<()>) -> bool {
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), f.body.len());
                if io.send(&f).await.is_err() {
                    return false
                }
                if io.flush().await.is_err() {
                    return false
                }
                io.io.set_position(0);
                if let Ok(Some(x)) = io.try_next().await {
                    x == f
                } else {
                    false
                }
            })
        }

        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }

    /// A data frame whose body exceeds `max_body_len` is rejected with
    /// `FrameTooLarge` carrying the announced body length.
    #[test]
    fn oversized_data_frame_is_rejected() {
        fn property(f: Frame<()>) -> bool {
            // Only non-empty data frames can exceed a limit of `len - 1`.
            if f.header.tag() != header::Tag::Data || f.body.is_empty() {
                return true
            }
            futures::executor::block_on(async move {
                let id = crate::connection::Id::random();
                let limit = f.body.len() - 1;
                let mut io = Io::new(id, futures::io::Cursor::new(Vec::new()), limit);
                if io.send(&f).await.is_err() || io.flush().await.is_err() {
                    return false
                }
                io.io.set_position(0);
                match io.try_next().await {
                    Err(FrameDecodeError::FrameTooLarge(n)) => n == f.body.len(),
                    _ => false
                }
            })
        }

        QuickCheck::new()
            .tests(1_000)
            .quickcheck(property as fn(Frame<()>) -> bool)
    }

    /// Empty input ends the stream cleanly instead of producing an error.
    #[test]
    fn eof_at_frame_boundary_ends_stream() {
        futures::executor::block_on(async {
            let id = crate::connection::Id::random();
            let mut stream = Io::new(id, futures::io::Cursor::new(Vec::new()), 0);
            match stream.try_next().await {
                Ok(None) => {}
                _ => panic!("expected a clean end of stream")
            }
        })
    }

    /// Input that stops part-way through a header is reported as `UnexpectedEof`.
    #[test]
    fn eof_inside_header_is_unexpected_eof() {
        futures::executor::block_on(async {
            let id = crate::connection::Id::random();
            let partial = vec![0u8; header::HEADER_SIZE - 1];
            let mut stream = Io::new(id, futures::io::Cursor::new(partial), 0);
            match stream.try_next().await {
                Err(FrameDecodeError::Io(e)) => {
                    assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof)
                }
                _ => panic!("expected an UnexpectedEof I/O error")
            }
        })
    }
}
258