async_codec/
framed_std.rs

1//! The `Framed` type for `futures::io::{AsyncRead, AsyncWrite}` streams
2
3use std::{
4    error::Error,
5    fmt::{
6        self,
7        Display,
8    },
9    io,
10    pin::Pin,
11};
12
13use bytes::{
14    buf::UninitSlice,
15    Buf,
16    BufMut,
17    BytesMut,
18};
19
20use futures::{
21    io::{
22        AsyncRead,
23        AsyncWrite,
24    },
25    prelude::*,
26    ready,
27    task::{
28        Context,
29        Poll,
30    },
31};
32
33use super::*;
34
35/// A wrapper around a byte stream that uses an `Encode + Decode` type to
36/// produce a `Sink + Stream`
37pub struct Framed<S, C> {
38    stream: S,
39    codec: C,
40    read_buf: BytesMut,
41    write_buf: BytesMut,
42}
43
44impl<S, C: Encode + Decode> Framed<S, C> {
45    /// Create a new framed stream from a byte stream and a codec.
46    pub fn new(stream: S, codec: C) -> Self {
47        Framed {
48            stream,
49            codec,
50            read_buf: BytesMut::default(),
51            write_buf: BytesMut::default(),
52        }
53    }
54}
55
/// Failure modes of reading a single frame.
#[derive(Debug)]
pub enum ReadFrameError<E> {
    /// The underlying stream reported an I/O error.
    Io(io::Error),
    /// The codec rejected the bytes it was given.
    Decode(E),
}

/// Flattens a read error into an `io::Error`; decode failures are reported as
/// `InvalidData` with the codec error attached as the inner error.
impl<E> From<ReadFrameError<E>> for io::Error
where
    E: Error + Send + Sync + 'static,
{
    fn from(err: ReadFrameError<E>) -> Self {
        match err {
            ReadFrameError::Io(io_err) => io_err,
            ReadFrameError::Decode(cause) => io::Error::new(io::ErrorKind::InvalidData, cause),
        }
    }
}

impl<E> Display for ReadFrameError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(cause) => write!(f, "error reading from stream: {}", cause),
            Self::Decode(cause) => write!(f, "error decoding frame: {}", cause),
        }
    }
}

impl<E> Error for ReadFrameError<E>
where
    E: Error + 'static,
{
    /// The wrapped error is always exposed as the source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(io_err) => io_err,
            Self::Decode(decode_err) => decode_err,
        };
        Some(cause)
    }
}
94
95impl<S, C> Stream for Framed<S, C>
96where
97    C: Decode,
98    S: AsyncRead,
99{
100    type Item = Result<C::Item, ReadFrameError<C::Error>>;
101
102    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
103        let (mut stream, codec, read_buf) = unsafe {
104            let this = self.get_unchecked_mut();
105            (
106                Pin::new_unchecked(&mut this.stream),
107                &mut this.codec,
108                &mut this.read_buf,
109            )
110        };
111        loop {
112            let empty = read_buf.is_empty();
113            if !empty {
114                let (consumed, decode_res) = codec.decode(read_buf);
115                read_buf.advance(consumed);
116                match decode_res {
117                    DecodeResult::Ok(value) => {
118                        return Poll::Ready(Some(Ok(value)));
119                    }
120                    DecodeResult::Err(e) => {
121                        return Poll::Ready(Some(Err(ReadFrameError::Decode(e))));
122                    }
123                    DecodeResult::UnexpectedEnd => {}
124                }
125            }
126
127            // Make sure there's at least one byte available to read into
128            read_buf.reserve(1);
129
130            // Safety: the buffer that we're reading into is immediately zeroed
131            // without reading it.
132            let eof = unsafe {
133                let n = {
134                    let b = zero_buf(read_buf.chunk_mut());
135                    match ready!(stream.as_mut().poll_read(cx, b)).map_err(ReadFrameError::Io) {
136                        Err(e) => return Poll::Ready(Some(Err(e))),
137                        Ok(n) => n,
138                    }
139                };
140
141                read_buf.advance_mut(n);
142
143                n == 0
144            };
145
146            if eof {
147                if empty {
148                    return Poll::Ready(None);
149                } else {
150                    return Poll::Ready(Some(Err(ReadFrameError::Io(
151                        io::ErrorKind::UnexpectedEof.into(),
152                    ))));
153                }
154            }
155        }
156    }
157}
158
/// Failure modes of writing a single frame to a stream.
#[derive(Debug)]
pub enum WriteFrameError<E> {
    /// The underlying stream reported an I/O error.
    Io(io::Error),
    /// The codec failed to encode the frame.
    Encode(E),
}

/// Flattens a write error into an `io::Error`; encode failures are reported as
/// `InvalidInput` with the codec error attached as the inner error.
impl<E> From<WriteFrameError<E>> for io::Error
where
    E: Error + Send + Sync + 'static,
{
    fn from(err: WriteFrameError<E>) -> Self {
        match err {
            WriteFrameError::Io(io_err) => io_err,
            WriteFrameError::Encode(cause) => io::Error::new(io::ErrorKind::InvalidInput, cause),
        }
    }
}

impl<E> Display for WriteFrameError<E>
where
    E: Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(cause) => write!(f, "error writing to stream: {}", cause),
            Self::Encode(cause) => write!(f, "error encoding frame: {}", cause),
        }
    }
}

impl<E> Error for WriteFrameError<E>
where
    E: Error + 'static,
{
    /// The wrapped error is always exposed as the source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let cause: &(dyn Error + 'static) = match self {
            Self::Io(io_err) => io_err,
            Self::Encode(encode_err) => encode_err,
        };
        Some(cause)
    }
}
197
198impl<S, C> Sink<C::Item> for Framed<S, C>
199where
200    C: Encode,
201    C::Error: std::fmt::Debug,
202    S: AsyncWrite,
203{
204    type Error = WriteFrameError<C::Error>;
205    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206        let (buffer, mut stream) = unsafe {
207            let this = self.get_unchecked_mut();
208            (&mut this.write_buf, Pin::new_unchecked(&mut this.stream))
209        };
210        loop {
211            if buffer.len() == 0 {
212                return Poll::Ready(Ok(()));
213            }
214            let written = ready!(stream
215                .as_mut()
216                .poll_write(cx, &buffer)
217                .map_err(WriteFrameError::Io))?;
218            if written == 0 {
219                return Poll::Ready(Err(WriteFrameError::Io(io::ErrorKind::WriteZero.into())));
220            }
221            buffer.advance(written);
222        }
223    }
224    fn start_send(self: Pin<&mut Self>, item: C::Item) -> Result<(), Self::Error> {
225        let (buffer, codec) = unsafe {
226            let this = self.get_unchecked_mut();
227            (&mut this.write_buf, &mut this.codec)
228        };
229        codec.reset();
230        loop {
231            let b = zero_buf(buffer.chunk_mut());
232            match codec.encode(&item, b) {
233                EncodeResult::Ok(len) => {
234                    // Safety: We made sure to zero the buffer above, so this is
235                    // safe assuming the advance_mut implementation won't
236                    // overflow the buffer
237                    unsafe { buffer.advance_mut(len) };
238                    return Ok(());
239                }
240                EncodeResult::Err(e) => return Err(WriteFrameError::Encode(e)),
241                EncodeResult::Overflow(0) => buffer.reserve(buffer.remaining_mut() * 2),
242                EncodeResult::Overflow(new_size) => buffer.reserve(new_size),
243            }
244        }
245    }
246    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
247        ready!(self.as_mut().poll_ready(cx))?;
248
249        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().stream) }
250            .poll_flush(cx)
251            .map_err(WriteFrameError::Io)
252    }
253
254    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
255        ready!(self.as_mut().poll_flush(cx))?;
256
257        unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().stream) }
258            .poll_close(cx)
259            .map_err(WriteFrameError::Io)
260    }
261}
262
263fn zero_buf(b: &mut UninitSlice) -> &mut [u8] {
264    for i in 0..b.len() {
265        b.write_byte(i, 0)
266    }
267    unsafe { std::mem::transmute(b) }
268}
269
#[cfg(test)]
mod test {
    use std::str::Utf8Error;

    use futures::{
        io::Cursor,
        prelude::*,
    };

    use super::*;

    /// Codec for newline-terminated UTF-8 lines, used to exercise `Framed`.
    struct LineCodec;

    impl Encode for LineCodec {
        type Item = String;
        type Error = ();

        /// Writes the line followed by `\n`, or asks for the exact space needed.
        fn encode(&mut self, item: &String, buf: &mut [u8]) -> EncodeResult<()> {
            let payload = item.as_bytes();
            let total = payload.len() + 1;
            if total > buf.len() {
                return EncodeResult::Overflow(total);
            }
            let (body, rest) = buf.split_at_mut(payload.len());
            body.copy_from_slice(payload);
            rest[0] = b'\n';
            Ok(total).into()
        }
    }

    impl Decode for LineCodec {
        type Item = String;
        type Error = Utf8Error;

        /// Decodes everything up to the first `\n`; the newline itself is
        /// consumed but not part of the frame.
        fn decode(&mut self, buf: &mut [u8]) -> (usize, DecodeResult<String, Utf8Error>) {
            match buf.iter().position(|&byte| byte == b'\n') {
                None => (0, DecodeResult::UnexpectedEnd),
                Some(newline) => {
                    let line = std::str::from_utf8(&buf[..newline]).map(String::from);
                    (newline + 1, line.into())
                }
            }
        }
    }

    const SHAKESPEARE: &str = r#"Now is the winter of our discontent
Made glorious summer by this sun of York.
Some are born great, some achieve greatness
And some have greatness thrust upon them.
Friends, Romans, countrymen - lend me your ears!
I come not to praise Caesar, but to bury him.
The evil that men do lives after them
The good is oft interred with their bones.
                    It is a tale
Told by an idiot, full of sound and fury
Signifying nothing.
Ay me! For aught that I could ever read,
Could ever hear by tale or history,
The course of true love never did run smooth.
I have full cause of weeping, but this heart
Shall break into a hundred thousand flaws,
Or ere I'll weep.-O Fool, I shall go mad!
                    Each your doing,
So singular in each particular,
Crowns what you are doing in the present deed,
That all your acts are queens.
"#;

    /// Reading the text through `Framed` yields exactly its lines, in order.
    #[async_std::test]
    async fn test_framed_stream() {
        let expected: Vec<String> = SHAKESPEARE.lines().map(String::from).collect();

        let source = Cursor::new(SHAKESPEARE.as_bytes().to_vec());
        let mut framed = Framed::new(source, LineCodec);

        let mut actual = Vec::new();
        while let Some(line) = framed.next().await.transpose().unwrap() {
            actual.push(line);
        }

        assert_eq!(actual, expected);
    }

    /// Sending every line through `Framed` reproduces the original text.
    #[async_std::test]
    async fn test_framed_sink() {
        let lines: Vec<String> = SHAKESPEARE.lines().map(String::from).collect();
        let mut actual = vec![0u8; SHAKESPEARE.as_bytes().len()];
        {
            let sink = Cursor::new(&mut actual);
            let mut framed = Framed::new(sink, LineCodec);
            for line in lines {
                framed.send(line).await.unwrap();
            }
        }
        let written = std::str::from_utf8(&actual).unwrap();
        assert_eq!(written, SHAKESPEARE);
    }
}