futures_length_delimited_frame/
lib.rs

1use core::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind};
6
7use futures_core::{ready, Stream};
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_sink::Sink;
10use pin_project_lite::pin_project;
11
12//
13//
14//
15pin_project! {
16    #[derive(Debug)]
17    pub struct Decoder<R> {
18        #[pin]
19        inner: R,
20        buf: Vec<u8>,
21        n_read: usize,
22        state: DecodeState,
23    }
24}
25
26impl<R: AsyncRead> Decoder<R> {
27    pub fn new(inner: R) -> Self {
28        Self::with_capacity(1024, inner)
29    }
30
31    pub fn with_capacity(cap: usize, inner: R) -> Self {
32        Self {
33            inner,
34            buf: vec![0; cap],
35            n_read: 0,
36            state: DecodeState::Head,
37        }
38    }
39
40    pub fn get_ref(&self) -> &R {
41        &self.inner
42    }
43
44    pub fn get_mut(&mut self) -> &mut R {
45        &mut self.inner
46    }
47
48    pub fn into_inner(self) -> R {
49        self.inner
50    }
51}
52
53#[derive(Debug, Clone, Copy)]
54enum DecodeState {
55    Head,
56    Data(usize),
57}
58
59impl<R: AsyncRead> Stream for Decoder<R> {
60    type Item = Result<Vec<u8>, IoError>;
61
62    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63        let mut this = self.project();
64
65        let field_len = core::mem::size_of::<u64>();
66
67        loop {
68            match *this.state {
69                DecodeState::Head => {
70                    if *this.n_read >= field_len {
71                        let data_len =
72                            u64::from_be_bytes(this.buf[..field_len].try_into().expect("Never"));
73
74                        if this.buf.len() < data_len as usize {
75                            this.buf.resize(data_len as usize, 0);
76                        }
77                        this.buf.rotate_left(field_len);
78                        *this.n_read -= field_len;
79
80                        *this.state = DecodeState::Data(data_len as usize);
81                        continue;
82                    }
83                }
84                DecodeState::Data(data_len) => {
85                    if *this.n_read >= data_len {
86                        let data = this.buf[..data_len].to_vec();
87
88                        this.buf.rotate_left(data_len);
89                        *this.n_read -= data_len;
90
91                        *this.state = DecodeState::Head;
92
93                        return Poll::Ready(Some(Ok(data)));
94                    }
95                }
96            }
97
98            match ready!(this
99                .inner
100                .as_mut()
101                .poll_read(cx, &mut this.buf[*this.n_read..]))
102            {
103                Ok(n) => {
104                    if n == 0 {
105                        match *this.state {
106                            DecodeState::Head => {
107                                if *this.n_read == 0 {
108                                    return Poll::Ready(None);
109                                } else {
110                                    return Poll::Ready(Some(Err(IoError::new(
111                                        IoErrorKind::Other,
112                                        format!("need more head, n:{}", field_len - *this.n_read),
113                                    ))));
114                                }
115                            }
116                            DecodeState::Data(data_len) => {
117                                if *this.n_read == 0 {
118                                    return Poll::Ready(Some(Err(IoError::new(
119                                        IoErrorKind::Other,
120                                        "no data".to_string(),
121                                    ))));
122                                } else {
123                                    return Poll::Ready(Some(Err(IoError::new(
124                                        IoErrorKind::Other,
125                                        format!("need more data, n:{}", data_len - *this.n_read),
126                                    ))));
127                                }
128                            }
129                        }
130                    }
131                    *this.n_read += n;
132                }
133                Err(err) => {
134                    //
135                    return Poll::Ready(Some(Err(err)));
136                }
137            }
138        }
139    }
140}
141
142//
143//
144//
145pin_project! {
146    #[derive(Debug)]
147    pub struct Encoder<W> {
148        #[pin]
149        inner: W,
150        buf: Vec<u8>,
151    }
152}
153
154impl<W: AsyncWrite> Encoder<W> {
155    pub fn new(inner: W) -> Self {
156        Self::with_capacity(1024, inner)
157    }
158
159    pub fn with_capacity(cap: usize, inner: W) -> Self {
160        Self {
161            inner,
162            buf: Vec::with_capacity(cap),
163        }
164    }
165
166    pub fn get_ref(&self) -> &W {
167        &self.inner
168    }
169
170    pub fn get_mut(&mut self) -> &mut W {
171        &mut self.inner
172    }
173
174    pub fn into_inner(self) -> W {
175        self.inner
176    }
177}
178
179// https://github.com/tokio-rs/tokio/blob/tokio-util-0.7.7/tokio-util/src/codec/framed_impl.rs#L253
180impl<T: AsRef<[u8]>, W: AsyncWrite> Sink<T> for Encoder<W> {
181    type Error = IoError;
182
183    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        if !self.buf.is_empty() {
185            <Encoder<W> as Sink<T>>::poll_flush(self.as_mut(), cx)
186        } else {
187            Poll::Ready(Ok(()))
188        }
189    }
190
191    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
192        let this = self.project();
193
194        let data = item.as_ref();
195        let data_len = data.len() as u64;
196
197        this.buf.extend_from_slice(data_len.to_be_bytes().as_ref());
198        this.buf.extend_from_slice(data);
199
200        Ok(())
201    }
202
203    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
204        let mut this = self.project();
205
206        let mut n_write = 0;
207        while !this.buf[n_write..].is_empty() {
208            let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[n_write..]))?;
209            n_write += n;
210
211            if n == 0 {
212                return Poll::Ready(Err(IoErrorKind::WriteZero.into()));
213            }
214        }
215        this.buf.clear();
216
217        ready!(this.inner.as_mut().poll_flush(cx))?;
218
219        Poll::Ready(Ok(()))
220    }
221
222    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223        ready!(<Encoder<W> as Sink<T>>::poll_flush(self.as_mut(), cx))?;
224
225        let mut this = self.project();
226        ready!(this.inner.as_mut().poll_close(cx))?;
227
228        Poll::Ready(Ok(()))
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::{io::Cursor, SinkExt as _, StreamExt as _};

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds a decoder reading `bytes` from the start of an in-memory cursor.
    fn decoder_over(bytes: Vec<u8>) -> Decoder<Cursor<Vec<u8>>> {
        Decoder::new(Cursor::new(bytes))
    }

    /// Pulls the next frame, turning "stream ended" into an error so callers can `?` it.
    async fn next_frame(
        decoder: &mut Decoder<Cursor<Vec<u8>>>,
    ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let frame = decoder.next().await.ok_or("decoder.next() is_none")??;
        Ok(frame)
    }

    #[test]
    fn simple() -> TestResult {
        futures_executor::block_on(async {
            // An empty source is a clean end of stream.
            let mut empty = decoder_over(Vec::new());
            assert!(empty.next().await.is_none());

            // Reuse the (still empty) cursor as the encoder's destination.
            let mut encoder = Encoder::new(empty.into_inner());
            encoder.send(&"abc").await?;
            encoder.send(&"12").await?;
            encoder.send(&[]).await?;

            // Rewind and read back what was written.
            let mut written = encoder.into_inner();
            written.set_position(0);

            let mut decoder = Decoder::new(written);
            assert_eq!(next_frame(&mut decoder).await?, b"abc");
            assert_eq!(next_frame(&mut decoder).await?, b"12");
            assert_eq!(next_frame(&mut decoder).await?, b"");
            assert!(decoder.next().await.is_none());

            Ok(())
        })
    }

    #[test]
    fn test_decoder() -> TestResult {
        futures_executor::block_on(async {
            // One complete frame followed by a clean EOF.
            let mut single = decoder_over(vec![
                0, 0, 0, 0, 0, 0, 0, 3, //
                97, 98, 99, //
            ]);
            assert_eq!(next_frame(&mut single).await?, b"abc");
            assert!(single.next().await.is_none());

            // A complete frame followed by a truncated header (3 of 8 bytes).
            let mut truncated = decoder_over(vec![
                0, 0, 0, 0, 0, 0, 0, 3, //
                97, 98, 99, //
                0, 0, 0,
            ]);
            assert_eq!(next_frame(&mut truncated).await?, b"abc");
            match truncated.next().await {
                Some(Err(err)) => {
                    assert_eq!(err.kind(), IoErrorKind::Other);
                    assert!(err.to_string().contains("need more head, n:5"));
                }
                other => panic!("{other:?}"),
            }

            // Several back-to-back frames of different sizes.
            let mut several = decoder_over(vec![
                0, 0, 0, 0, 0, 0, 0, 2, //
                1, 2, //
                0, 0, 0, 0, 0, 0, 0, 1, //
                3, //
                0, 0, 0, 0, 0, 0, 0, 3, //
                4, 5, 6, //
            ]);
            assert_eq!(next_frame(&mut several).await?, &[1, 2]);
            assert_eq!(next_frame(&mut several).await?, &[3]);
            assert_eq!(next_frame(&mut several).await?, &[4, 5, 6]);
            assert!(several.next().await.is_none());

            Ok(())
        })
    }

    #[test]
    fn test_encoder() -> TestResult {
        futures_executor::block_on(async {
            // A single frame.
            let mut encoder = Encoder::new(Cursor::new(Vec::<u8>::new()));
            encoder.send(&"abc").await?;
            let written = encoder.into_inner();
            assert_eq!(
                written.get_ref(),
                &[
                    0, 0, 0, 0, 0, 0, 0, 3, //
                    97, 98, 99, //
                ]
            );

            // Several frames, sent from different `AsRef<[u8]>` item types.
            let mut encoder = Encoder::new(Cursor::new(Vec::<u8>::new()));
            encoder.send(&[1, 2]).await?;
            encoder.send(&[3]).await?;
            encoder.send(vec![4, 5, 6]).await?;
            let written = encoder.into_inner();
            assert_eq!(
                written.get_ref(),
                &[
                    0, 0, 0, 0, 0, 0, 0, 2, //
                    1, 2, //
                    0, 0, 0, 0, 0, 0, 0, 1, //
                    3, //
                    0, 0, 0, 0, 0, 0, 0, 3, //
                    4, 5, 6, //
                ]
            );

            Ok(())
        })
    }
}