1use bytes_old::BufMut;
21use bytes_old::Bytes;
22use bytes_old::BytesMut;
23use futures::try_ready;
24use futures::Async;
25use futures::Poll;
26use futures::Stream;
27use tokio_io::codec::Decoder;
28
29pub fn decode<In, Dec>(input: In, decoder: Dec) -> LayeredDecode<In, Dec>
32where
33 In: Stream<Item = Bytes>,
34 Dec: Decoder,
35{
36 LayeredDecode {
37 input,
38 decoder,
39 buf: BytesMut::with_capacity(8 * 1024),
41 eof: false,
42 is_readable: false,
43 }
44}
45
46#[derive(Debug)]
48pub struct LayeredDecode<In, Dec> {
49 input: In,
50 decoder: Dec,
51 buf: BytesMut,
52 eof: bool,
53 is_readable: bool,
54}
55
56impl<In, Dec> Stream for LayeredDecode<In, Dec>
57where
58 In: Stream<Item = Bytes>,
59 Dec: Decoder,
60 Dec::Error: From<In::Error>,
61{
62 type Item = Dec::Item;
63 type Error = Dec::Error;
64
65 fn poll(&mut self) -> Poll<Option<Self::Item>, Dec::Error> {
66 loop {
70 if self.is_readable {
71 if self.eof {
72 let ret = if self.buf.is_empty() {
73 None
74 } else {
75 self.decoder.decode_eof(&mut self.buf)?
76 };
77 return Ok(Async::Ready(ret));
78 }
79 if let Some(frame) = self.decoder.decode(&mut self.buf)? {
80 return Ok(Async::Ready(Some(frame)));
81 }
82 self.is_readable = false;
83 }
84
85 assert!(!self.eof);
86
87 match try_ready!(self.input.poll()) {
88 Some(v) => {
89 self.buf.reserve(v.len());
90 self.buf.put(v);
91 }
92 None => self.eof = true,
93 }
94
95 self.is_readable = true;
96 }
97 }
98}
99
100impl<In, Dec> LayeredDecode<In, Dec>
101where
102 In: Stream<Item = Bytes>,
103{
104 #[inline]
106 pub fn into_inner(self) -> In {
107 self.input
109 }
110
111 #[inline]
113 pub fn get_ref(&self) -> &In {
114 &self.input
115 }
116
117 #[inline]
119 pub fn get_mut(&mut self) -> &mut In {
120 &mut self.input
121 }
122}
123
#[cfg(test)]
mod test {
    use std::io;

    use anyhow::Error;
    use anyhow::Result;
    use bytes_old::Bytes;
    use futures::stream;
    use futures::Stream;
    use futures03::compat::Future01CompatExt;

    use super::*;

    /// Minimal length-prefixed decoder used by the tests: the first byte of a
    /// frame is the payload length, followed by that many payload bytes.
    #[derive(Default)]
    struct TestDecoder {}

    impl Decoder for TestDecoder {
        type Item = BytesMut;
        type Error = Error;

        fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>> {
            // The length prefix itself has not arrived yet.
            if buf.is_empty() {
                return Ok(None);
            }

            let payload_len = usize::from(u8::from_le(buf[0]));

            // Prefix byte plus payload must both be buffered before a frame
            // can be produced.
            if buf.len() <= payload_len {
                return Ok(None);
            }

            // Discard the prefix, then hand out exactly the payload.
            buf.split_to(1);
            Ok(Some(buf.split_to(payload_len)))
        }
    }

    /// Runs `chunks` through `decode` with a `TestDecoder`, drives the
    /// resulting stream to completion and returns all decoded frames
    /// concatenated into a single byte vector.
    fn decode_all(chunks: Vec<Bytes>) -> Vec<u8> {
        let runtime = tokio::runtime::Runtime::new().unwrap();

        let inp = stream::iter_ok::<_, io::Error>(chunks);
        let dec = decode(inp, TestDecoder::default());

        let xfer = dec
            .map_err::<(), _>(|err| {
                panic!("bad = {err}");
            })
            .forward(Vec::new());

        let (_, frames) = runtime.block_on(xfer.compat()).unwrap();
        frames
            .into_iter()
            .flat_map(|frame| frame.as_ref().to_vec())
            .collect::<Vec<_>>()
    }

    /// One frame delivered in a single chunk.
    #[test]
    fn simple() {
        let out = decode_all(vec![Bytes::from(&b"\x0Dhello, world!"[..])]);
        assert_eq!(out, b"hello, world!");
    }

    /// Many frames in one chunk that is far larger than the initial buffer.
    #[test]
    fn large() {
        let out = decode_all(vec![Bytes::from("\x0Dhello, world!".repeat(5000))]);

        assert_eq!(out, "hello, world!".repeat(5000).as_bytes());
    }

    /// One frame whose bytes are split across two chunks.
    #[test]
    fn partial() {
        let out = decode_all(vec![
            Bytes::from(&b"\x0Dhel"[..]),
            Bytes::from(&b"lo, world!"[..]),
        ]);
        assert_eq!(out, b"hello, world!");
    }
}