1use core::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5use std::io::{Error as IoError, ErrorKind as IoErrorKind};
6
7use futures_core::{ready, Stream};
8use futures_io::{AsyncRead, AsyncWrite};
9use futures_sink::Sink;
10use pin_project_lite::pin_project;
11
pin_project! {
    /// Decodes length-prefixed frames read from an [`AsyncRead`].
    ///
    /// Each frame is an 8-byte big-endian `u64` payload length followed by that many payload
    /// bytes; every decoded payload is yielded as a `Vec<u8>` through the [`Stream`] impl.
    #[derive(Debug)]
    pub struct Decoder<R> {
        // Underlying reader; structurally pinned so it can be polled through `Pin`.
        #[pin]
        inner: R,
        // Read buffer; `buf[..n_read]` holds bytes read from `inner` but not yet consumed.
        buf: Vec<u8>,
        // Number of valid, unconsumed bytes at the front of `buf`.
        n_read: usize,
        // Whether the decoder is waiting for a length header or for payload bytes.
        state: DecodeState,
    }
}
25
26impl<R: AsyncRead> Decoder<R> {
27 pub fn new(inner: R) -> Self {
28 Self::with_capacity(1024, inner)
29 }
30
31 pub fn with_capacity(cap: usize, inner: R) -> Self {
32 Self {
33 inner,
34 buf: vec![0; cap],
35 n_read: 0,
36 state: DecodeState::Head,
37 }
38 }
39
40 pub fn get_ref(&self) -> &R {
41 &self.inner
42 }
43
44 pub fn get_mut(&mut self) -> &mut R {
45 &mut self.inner
46 }
47
48 pub fn into_inner(self) -> R {
49 self.inner
50 }
51}
52
/// Progress of a [`Decoder`] through the current frame.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// Waiting until the 8-byte length header is buffered.
    Head,
    /// Header consumed; waiting until this many payload bytes are buffered.
    Data(usize),
}
58
59impl<R: AsyncRead> Stream for Decoder<R> {
60 type Item = Result<Vec<u8>, IoError>;
61
62 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63 let mut this = self.project();
64
65 let field_len = core::mem::size_of::<u64>();
66
67 loop {
68 match *this.state {
69 DecodeState::Head => {
70 if *this.n_read >= field_len {
71 let data_len =
72 u64::from_be_bytes(this.buf[..field_len].try_into().expect("Never"));
73
74 if this.buf.len() < data_len as usize {
75 this.buf.resize(data_len as usize, 0);
76 }
77 this.buf.rotate_left(field_len);
78 *this.n_read -= field_len;
79
80 *this.state = DecodeState::Data(data_len as usize);
81 continue;
82 }
83 }
84 DecodeState::Data(data_len) => {
85 if *this.n_read >= data_len {
86 let data = this.buf[..data_len].to_vec();
87
88 this.buf.rotate_left(data_len);
89 *this.n_read -= data_len;
90
91 *this.state = DecodeState::Head;
92
93 return Poll::Ready(Some(Ok(data)));
94 }
95 }
96 }
97
98 match ready!(this
99 .inner
100 .as_mut()
101 .poll_read(cx, &mut this.buf[*this.n_read..]))
102 {
103 Ok(n) => {
104 if n == 0 {
105 match *this.state {
106 DecodeState::Head => {
107 if *this.n_read == 0 {
108 return Poll::Ready(None);
109 } else {
110 return Poll::Ready(Some(Err(IoError::new(
111 IoErrorKind::Other,
112 format!("need more head, n:{}", field_len - *this.n_read),
113 ))));
114 }
115 }
116 DecodeState::Data(data_len) => {
117 if *this.n_read == 0 {
118 return Poll::Ready(Some(Err(IoError::new(
119 IoErrorKind::Other,
120 "no data".to_string(),
121 ))));
122 } else {
123 return Poll::Ready(Some(Err(IoError::new(
124 IoErrorKind::Other,
125 format!("need more data, n:{}", data_len - *this.n_read),
126 ))));
127 }
128 }
129 }
130 }
131 *this.n_read += n;
132 }
133 Err(err) => {
134 return Poll::Ready(Some(Err(err)));
136 }
137 }
138 }
139 }
140}
141
pin_project! {
    /// Encodes byte slices as length-prefixed frames onto an [`AsyncWrite`].
    ///
    /// Each item sent through the [`Sink`] impl is queued as an 8-byte big-endian `u64`
    /// length followed by the item's bytes, and written out to the inner writer on flush.
    #[derive(Debug)]
    pub struct Encoder<W> {
        // Underlying writer; structurally pinned so it can be polled through `Pin`.
        #[pin]
        inner: W,
        // Encoded frames queued by `start_send` and pending a flush to `inner`.
        buf: Vec<u8>,
    }
}
153
154impl<W: AsyncWrite> Encoder<W> {
155 pub fn new(inner: W) -> Self {
156 Self::with_capacity(1024, inner)
157 }
158
159 pub fn with_capacity(cap: usize, inner: W) -> Self {
160 Self {
161 inner,
162 buf: Vec::with_capacity(cap),
163 }
164 }
165
166 pub fn get_ref(&self) -> &W {
167 &self.inner
168 }
169
170 pub fn get_mut(&mut self) -> &mut W {
171 &mut self.inner
172 }
173
174 pub fn into_inner(self) -> W {
175 self.inner
176 }
177}
178
179impl<T: AsRef<[u8]>, W: AsyncWrite> Sink<T> for Encoder<W> {
181 type Error = IoError;
182
183 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 if !self.buf.is_empty() {
185 <Encoder<W> as Sink<T>>::poll_flush(self.as_mut(), cx)
186 } else {
187 Poll::Ready(Ok(()))
188 }
189 }
190
191 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
192 let this = self.project();
193
194 let data = item.as_ref();
195 let data_len = data.len() as u64;
196
197 this.buf.extend_from_slice(data_len.to_be_bytes().as_ref());
198 this.buf.extend_from_slice(data);
199
200 Ok(())
201 }
202
203 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
204 let mut this = self.project();
205
206 let mut n_write = 0;
207 while !this.buf[n_write..].is_empty() {
208 let n = ready!(this.inner.as_mut().poll_write(cx, &this.buf[n_write..]))?;
209 n_write += n;
210
211 if n == 0 {
212 return Poll::Ready(Err(IoErrorKind::WriteZero.into()));
213 }
214 }
215 this.buf.clear();
216
217 ready!(this.inner.as_mut().poll_flush(cx))?;
218
219 Poll::Ready(Ok(()))
220 }
221
222 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
223 ready!(<Encoder<W> as Sink<T>>::poll_flush(self.as_mut(), cx))?;
224
225 let mut this = self.project();
226 ready!(this.inner.as_mut().poll_close(cx))?;
227
228 Poll::Ready(Ok(()))
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    use futures_util::{io::Cursor, SinkExt as _, StreamExt as _};

    /// Round-trips frames (including an empty one) through `Encoder` and back through
    /// `Decoder`, and checks that an empty reader yields no frames.
    #[test]
    fn simple() -> Result<(), Box<dyn std::error::Error>> {
        futures_executor::block_on(async {
            // An empty reader ends the stream immediately.
            let cursor: Cursor<Vec<u8>> = Cursor::new(vec![]);

            let mut decoder = Decoder::new(cursor);
            assert!(decoder.next().await.is_none());

            let cursor = decoder.into_inner();

            let mut encoder = Encoder::new(cursor);
            encoder.send(&"abc").await?;
            encoder.send(&"12").await?;
            encoder.send(&[]).await?;

            // Rewind so the decoder reads back what the encoder wrote.
            let mut cursor = encoder.into_inner();
            cursor.set_position(0);

            let mut decoder = Decoder::new(cursor);
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                b"abc"
            );
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                b"12"
            );
            // A zero-length frame decodes to an empty payload, not end-of-stream.
            assert_eq!(decoder.next().await.ok_or("decoder.next() is_none")??, b"");
            assert!(decoder.next().await.is_none());

            Ok(())
        })
    }

    /// Decodes hand-built byte streams: a single frame, a frame followed by a truncated
    /// header, and several back-to-back frames.
    #[test]
    fn test_decoder() -> Result<(), Box<dyn std::error::Error>> {
        futures_executor::block_on(async {
            // One frame: 8-byte big-endian length 3, then "abc".
            let mut r: Cursor<Vec<u8>> = Cursor::new(vec![
                0, 0, 0, 0, 0, 0, 0, 3, // length = 3
                97, 98, 99, // "abc"
            ]);
            r.set_position(0);
            let mut decoder = Decoder::new(r);
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                b"abc"
            );
            assert!(decoder.next().await.is_none());

            // One frame followed by only 3 of the next header's 8 bytes.
            let mut r: Cursor<Vec<u8>> = Cursor::new(vec![
                0, 0, 0, 0, 0, 0, 0, 3, // length = 3
                97, 98, 99, // "abc"
                0, 0, 0, // truncated header: 5 bytes missing
            ]);
            r.set_position(0);
            let mut decoder = Decoder::new(r);
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                b"abc"
            );
            match decoder.next().await {
                Some(Err(err)) => {
                    assert_eq!(err.kind(), IoErrorKind::Other);
                    assert!(err.to_string().contains("need more head, n:5"));
                }
                x => panic!("{x:?}"),
            };

            // Three back-to-back frames: [1, 2], [3], [4, 5, 6].
            let mut r: Cursor<Vec<u8>> = Cursor::new(vec![
                0, 0, 0, 0, 0, 0, 0, 2, // length = 2
                1, 2, //
                0, 0, 0, 0, 0, 0, 0, 1, // length = 1
                3, //
                0, 0, 0, 0, 0, 0, 0, 3, // length = 3
                4, 5, 6, //
            ]);
            r.set_position(0);
            let mut decoder = Decoder::new(r);
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                &[1, 2]
            );
            assert_eq!(decoder.next().await.ok_or("decoder.next() is_none")??, &[3]);
            assert_eq!(
                decoder.next().await.ok_or("decoder.next() is_none")??,
                &[4, 5, 6]
            );
            assert!(decoder.next().await.is_none());

            Ok(())
        })
    }

    /// Checks the exact bytes the encoder writes for one frame and for several frames.
    #[test]
    fn test_encoder() -> Result<(), Box<dyn std::error::Error>> {
        futures_executor::block_on(async {
            let w: Cursor<Vec<u8>> = Cursor::new(vec![]);
            let mut encoder = Encoder::new(w);
            encoder.send(&"abc").await?;
            assert_eq!(
                encoder.into_inner().get_ref(),
                &[
                    0, 0, 0, 0, 0, 0, 0, 3, // length = 3
                    97, 98, 99, // "abc"
                ]
            );

            // Items of different `AsRef<[u8]>` types are framed the same way.
            let w: Cursor<Vec<u8>> = Cursor::new(vec![]);
            let mut encoder = Encoder::new(w);
            encoder.send(&[1, 2]).await?;
            encoder.send(&[3]).await?;
            encoder.send(vec![4, 5, 6]).await?;
            assert_eq!(
                encoder.into_inner().get_ref(),
                &[
                    0, 0, 0, 0, 0, 0, 0, 2, // length = 2
                    1, 2, //
                    0, 0, 0, 0, 0, 0, 0, 1, // length = 1
                    3, //
                    0, 0, 0, 0, 0, 0, 0, 3, // length = 3
                    4, 5, 6, //
                ]
            );

            Ok(())
        })
    }
}