futures_01_ext/
decode.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
//! A layered `Decoder` adapter for `Stream` transformations.
//!
//! This module implements an adapter that allows a `tokio_io::codec::Decoder` implementation
//! to transform a `Stream`: specifically, to decode a `Stream` of `Bytes` into some
//! structured type.
//!
//! This allows multiple protocols to be layered and composed with operations on `Stream`s,
//! rather than restricting all codec operations to `AsyncRead`/`AsyncWrite` operations on
//! an underlying transport.
19
20use bytes_old::BufMut;
21use bytes_old::Bytes;
22use bytes_old::BytesMut;
23use futures::try_ready;
24use futures::Async;
25use futures::Poll;
26use futures::Stream;
27use tokio_io::codec::Decoder;
28
29/// Returns a stream that will yield decoded items that are the result of decoding
30/// [Bytes] of the underlying [Stream] by using the provided [Decoder]
31pub fn decode<In, Dec>(input: In, decoder: Dec) -> LayeredDecode<In, Dec>
32where
33    In: Stream<Item = Bytes>,
34    Dec: Decoder,
35{
36    LayeredDecode {
37        input,
38        decoder,
39        // 8KB is a reasonable default
40        buf: BytesMut::with_capacity(8 * 1024),
41        eof: false,
42        is_readable: false,
43    }
44}
45
46/// Stream returned by the [decode] function
47#[derive(Debug)]
48pub struct LayeredDecode<In, Dec> {
49    input: In,
50    decoder: Dec,
51    buf: BytesMut,
52    eof: bool,
53    is_readable: bool,
54}
55
56impl<In, Dec> Stream for LayeredDecode<In, Dec>
57where
58    In: Stream<Item = Bytes>,
59    Dec: Decoder,
60    Dec::Error: From<In::Error>,
61{
62    type Item = Dec::Item;
63    type Error = Dec::Error;
64
65    fn poll(&mut self) -> Poll<Option<Self::Item>, Dec::Error> {
66        // This is adapted from Framed::poll in tokio. This does its own thing
67        // because converting the Bytes input stream to an Io object and then
68        // running it through Framed is pointless.
69        loop {
70            if self.is_readable {
71                if self.eof {
72                    let ret = if self.buf.is_empty() {
73                        None
74                    } else {
75                        self.decoder.decode_eof(&mut self.buf)?
76                    };
77                    return Ok(Async::Ready(ret));
78                }
79                if let Some(frame) = self.decoder.decode(&mut self.buf)? {
80                    return Ok(Async::Ready(Some(frame)));
81                }
82                self.is_readable = false;
83            }
84
85            assert!(!self.eof);
86
87            match try_ready!(self.input.poll()) {
88                Some(v) => {
89                    self.buf.reserve(v.len());
90                    self.buf.put(v);
91                }
92                None => self.eof = true,
93            }
94
95            self.is_readable = true;
96        }
97    }
98}
99
100impl<In, Dec> LayeredDecode<In, Dec>
101where
102    In: Stream<Item = Bytes>,
103{
104    /// Consume this combinator and returned the underlying stream
105    #[inline]
106    pub fn into_inner(self) -> In {
107        // TODO: do we want to check that buf is empty? otherwise we might lose data
108        self.input
109    }
110
111    /// Returns reference to the underlying stream
112    #[inline]
113    pub fn get_ref(&self) -> &In {
114        &self.input
115    }
116
117    /// Returns mutable reference to the underlying stream
118    #[inline]
119    pub fn get_mut(&mut self) -> &mut In {
120        &mut self.input
121    }
122}
123
#[cfg(test)]
mod test {
    use std::io;

    use anyhow::Error;
    use anyhow::Result;
    use bytes_old::Bytes;
    use futures::stream;
    use futures::Stream;
    use futures03::compat::Future01CompatExt;

    use super::*;

    /// Decodes length-prefixed frames: one length byte followed by that many payload
    /// bytes. Relies on the default `decode_eof`.
    #[derive(Default)]
    struct TestDecoder {}

    impl Decoder for TestDecoder {
        type Item = BytesMut;
        type Error = Error;

        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>> {
            if !buf.is_empty() {
                let expected_len: usize = u8::from_le(buf[0]).into();
                if buf.len() > expected_len {
                    buf.split_to(1);
                    Ok(Some(buf.split_to(expected_len)))
                } else {
                    Ok(None)
                }
            } else {
                Ok(None)
            }
        }
    }

    /// Feeds `chunks` through a `TestDecoder` and returns the concatenated payloads
    /// of every decoded frame, or the first error the decoded stream produced.
    fn decode_all(chunks: Vec<Bytes>) -> Result<Vec<u8>> {
        let runtime = tokio::runtime::Runtime::new()?;
        let dec = decode(
            stream::iter_ok::<_, io::Error>(chunks),
            TestDecoder::default(),
        );
        let frames = runtime.block_on(dec.collect().compat())?;
        Ok(frames
            .into_iter()
            .flat_map(|frame| frame.as_ref().to_vec())
            .collect::<Vec<_>>())
    }

    #[test]
    fn simple() {
        let out = decode_all(vec![Bytes::from(&b"\x0Dhello, world!"[..])]).unwrap();
        assert_eq!(out, b"hello, world!");
    }

    #[test]
    fn large() {
        let out = decode_all(vec![Bytes::from("\x0Dhello, world!".repeat(5000))]).unwrap();
        assert_eq!(out, "hello, world!".repeat(5000).as_bytes());
    }

    #[test]
    fn partial() {
        let out = decode_all(vec![
            Bytes::from(&b"\x0Dhel"[..]),
            Bytes::from(&b"lo, world!"[..]),
        ])
        .unwrap();
        assert_eq!(out, b"hello, world!");
    }

    #[test]
    fn empty() {
        // An input that ends immediately yields no frames and no error.
        let out = decode_all(vec![]).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn empty_chunks() {
        // Zero-length chunks around a frame must not disturb decoding.
        let out = decode_all(vec![
            Bytes::new(),
            Bytes::from(&b"\x02hi"[..]),
            Bytes::new(),
        ])
        .unwrap();
        assert_eq!(out, b"hi");
    }

    #[test]
    fn truncated_frame_is_error() {
        // The length byte promises five payload bytes but only three arrive before
        // EOF, so the default `decode_eof` must report the leftover bytes.
        assert!(decode_all(vec![Bytes::from(&b"\x05hel"[..])]).is_err());
    }
}