1use bincode::config;
2use byteorder::{ByteOrder, NetworkEndian};
3use bytes::buf::Buf;
4use bytes::BytesMut;
5use futures_core::ready;
6use serde::Deserialize;
7use std::io;
8use std::marker::PhantomData;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
/// Generates a public `AsyncBincodeReader<R, T>` newtype for one flavor of async reader.
///
/// * `$read_trait` — the read trait the wrapped reader must implement; the generated
///   `Stream` impl requires `R: $read_trait + Unpin`.
/// * `$internal_poll_reader` — path to the function passed to `internal_poll_next` as its
///   `poll_reader` callback, i.e. something usable as
///   `Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<Result<usize, io::Error>> + Copy`.
///
/// The generated type wraps `crate::reader::AsyncBincodeReader`, which holds the reader, a
/// read-ahead buffer, and a `PhantomData<T>` marker.
macro_rules! make_reader {
    ($read_trait:path, $internal_poll_reader:path) => {
        /// A wrapper around an asynchronous reader that yields bincode-decoded values of
        /// type `T`.
        ///
        /// To use, provide a reader that implements
        #[doc=concat!("[`", stringify!($read_trait), "`],")]
        /// then poll the wrapper as a [`futures_core::Stream`].
        ///
        /// Every value on the wire must be framed as a 4-byte network-endian (big-endian) `u32`
        /// payload length followed by that many bytes of bincode data encoded with
        /// `bincode::config::standard()`.
        #[derive(Debug)]
        pub struct AsyncBincodeReader<R, T>(crate::reader::AsyncBincodeReader<R, T>);

        // Implemented by hand: `T` is only held as `PhantomData`, so the wrapper can be `Unpin`
        // whenever the reader is, without requiring `T: Unpin`.
        impl<R, T> Unpin for AsyncBincodeReader<R, T> where R: Unpin {}

        /// Wraps a default-constructed reader (via the `From<R>` impl).
        impl<R, T> Default for AsyncBincodeReader<R, T>
        where
            R: Default,
        {
            fn default() -> Self {
                Self::from(R::default())
            }
        }

        /// Wraps `reader` with an empty read-ahead buffer whose initial capacity is 8192 bytes.
        impl<R, T> From<R> for AsyncBincodeReader<R, T> {
            fn from(reader: R) -> Self {
                Self(crate::reader::AsyncBincodeReader {
                    buffer: ::bytes::BytesMut::with_capacity(8192),
                    reader,
                    into: ::std::marker::PhantomData,
                })
            }
        }

        impl<R, T> AsyncBincodeReader<R, T> {
            /// Gets a reference to the underlying reader.
            pub fn get_ref(&self) -> &R {
                &self.0.reader
            }

            /// Gets a mutable reference to the underlying reader.
            ///
            /// Bytes already pulled into the internal buffer (see [`Self::buffer`]) are not
            /// visible through this reference; reading from the reader directly bypasses them
            /// and may leave the wrapper out of step with the frame boundaries.
            pub fn get_mut(&mut self) -> &mut R {
                &mut self.0.reader
            }

            /// Returns the bytes read from the underlying reader that have not yet been
            /// consumed by a decoded item.
            pub fn buffer(&self) -> &[u8] {
                &self.0.buffer[..]
            }

            /// Consumes the wrapper and returns the underlying reader.
            ///
            /// Any bytes still held in the internal buffer are discarded.
            pub fn into_inner(self) -> R {
                self.0.reader
            }
        }

        /// Each item is `Ok(value)` for a decoded frame or `Err(_)` for an I/O or decode
        /// failure; the stream ends (`None`) when the reader reports end-of-file with no
        /// buffered bytes.
        impl<R, T> ::futures_core::Stream for AsyncBincodeReader<R, T>
        where
            for<'a> T: ::serde::Deserialize<'a>,
            R: $read_trait + Unpin,
        {
            type Item = Result<T, bincode::error::DecodeError>;
            fn poll_next(
                mut self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context,
            ) -> std::task::Poll<Option<Self::Item>> {
                // Delegate to the shared state, using `$internal_poll_reader` to read raw bytes.
                std::pin::Pin::new(&mut self.0).internal_poll_next(cx, $internal_poll_reader)
            }
        }
    };
}
94
95#[derive(Debug)]
96pub(crate) struct AsyncBincodeReader<R, T> {
97 pub(crate) reader: R,
98 pub(crate) buffer: BytesMut,
99 pub(crate) into: PhantomData<T>,
100}
101
102impl<R, T> Unpin for AsyncBincodeReader<R, T> where R: Unpin {}
103
/// Outcome of a successful attempt to fill the read-ahead buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FillResult {
    /// The buffer now holds at least the requested number of bytes.
    Filled,
    /// The reader reported end-of-file while the buffer was completely empty.
    Eof,
}
108
109impl<R: Unpin, T> AsyncBincodeReader<R, T>
110where
111 for<'a> T: Deserialize<'a>,
112{
113 pub(crate) fn internal_poll_next<F>(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context,
116 poll_reader: F,
117 ) -> Poll<Option<Result<T, bincode::error::DecodeError>>>
118 where
119 F: Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<Result<usize, io::Error>> + Copy,
120 {
121 if let FillResult::Eof = ready!(self.as_mut().fill(cx, 5, poll_reader).map_err(|inner| {
122 bincode::error::DecodeError::Io {
123 inner,
124 additional: 4,
125 }
126 }))? {
127 return Poll::Ready(None);
128 }
129
130 let message_size: u32 = NetworkEndian::read_u32(&self.buffer[..4]);
131 let target_buffer_size = message_size as usize;
132
133 ready!(self
135 .as_mut()
136 .fill(cx, target_buffer_size + 4, poll_reader)
137 .map_err(|inner| {
138 bincode::error::DecodeError::Io {
139 inner,
140 additional: target_buffer_size,
141 }
142 }))?;
143
144 self.buffer.advance(4);
145 let (message, decoded) = bincode::serde::decode_from_slice(
146 &self.buffer[..target_buffer_size],
147 config::standard().with_limit::<{ u32::MAX as usize }>(),
148 )?;
149 if decoded != target_buffer_size {
150 return Poll::Ready(Some(Err(bincode::error::DecodeError::OtherString(
151 format!("only decoded {decoded} out of {target_buffer_size}-length message"),
152 ))));
153 }
154 self.buffer.advance(target_buffer_size);
155 Poll::Ready(Some(Ok(message)))
156 }
157
158 fn fill<F>(
159 mut self: Pin<&mut Self>,
160 cx: &mut Context,
161 target_size: usize,
162 poll_reader: F,
163 ) -> Poll<Result<FillResult, io::Error>>
164 where
165 F: Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<Result<usize, io::Error>>,
166 {
167 if self.buffer.len() >= target_size {
168 return Poll::Ready(Ok(FillResult::Filled));
170 }
171
172 if self.buffer.capacity() < target_size {
175 let missing = target_size - self.buffer.capacity();
176 self.buffer.reserve(missing);
177 }
178
179 let had = self.buffer.len();
180 let mut rest = self.buffer.split_off(had);
182 let max = rest.capacity();
185 unsafe { rest.set_len(max) };
186
187 while self.buffer.len() < target_size {
188 match poll_reader(Pin::new(&mut self.reader), cx, &mut rest[..]) {
189 Poll::Ready(result) => {
190 match result {
191 Ok(n) => {
192 if n == 0 {
193 if self.buffer.is_empty() {
194 return Poll::Ready(Ok(FillResult::Eof));
195 } else {
196 return Poll::Ready(Err(io::Error::from(
197 io::ErrorKind::BrokenPipe,
198 )));
199 }
200 }
201
202 let read = rest.split_to(n);
204 self.buffer.unsplit(read);
205 }
206 Err(err) => {
207 rest.truncate(0);
209 self.buffer.unsplit(rest);
210 return Poll::Ready(Err(err));
211 }
212 }
213 }
214 Poll::Pending => {
215 rest.truncate(0);
217 self.buffer.unsplit(rest);
218 return Poll::Pending;
219 }
220 }
221 }
222
223 Poll::Ready(Ok(FillResult::Filled))
224 }
225}