async_bincode/
reader.rs

1use bincode::config;
2use byteorder::{ByteOrder, NetworkEndian};
3use bytes::buf::Buf;
4use bytes::BytesMut;
5use futures_core::ready;
6use serde::Deserialize;
7use std::io;
8use std::marker::PhantomData;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
// Generates the public `AsyncBincodeReader` wrapper for one async I/O ecosystem.
//
// * `$read_trait` — the async read trait the wrapped reader must implement; it is used as
//   a bound on the `Stream` impl and named in the generated documentation.
// * `$internal_poll_reader` — a function handed to
//   `crate::reader::AsyncBincodeReader::internal_poll_next`; it must be
//   `Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<io::Result<usize>> + Copy`,
//   perform one read into the given slice, and return how many bytes it wrote.
macro_rules! make_reader {
    ($read_trait:path, $internal_poll_reader:path) => {
        /// A wrapper around an asynchronous reader that produces an asynchronous stream of
        /// bincode-decoded values.
        ///
        /// To use, provide a reader that implements
        #[doc=concat!("[`", stringify!($read_trait), "`],")]
        /// and then use [`futures_core::Stream`] to access the deserialized values.
        ///
        /// Note that the sender *must* prefix each serialized item with its encoded length in
        /// bytes, written as a four-byte network-endian (big-endian) unsigned integer, and must
        /// encode the item itself with bincode's standard configuration. Use the marker trait
        /// [`AsyncDestination`] to add the prefix automatically when using
        /// [`AsyncBincodeWriter`].
        #[derive(Debug)]
        pub struct AsyncBincodeReader<R, T>(crate::reader::AsyncBincodeReader<R, T>);

        impl<R, T> Unpin for AsyncBincodeReader<R, T> where R: Unpin {}

        impl<R, T> Default for AsyncBincodeReader<R, T>
        where
            R: Default,
        {
            fn default() -> Self {
                Self::from(R::default())
            }
        }

        impl<R, T> From<R> for AsyncBincodeReader<R, T> {
            fn from(reader: R) -> Self {
                Self(crate::reader::AsyncBincodeReader {
                    buffer: ::bytes::BytesMut::with_capacity(8192),
                    reader,
                    into: ::std::marker::PhantomData,
                })
            }
        }

        impl<R, T> AsyncBincodeReader<R, T> {
            /// Gets a reference to the underlying reader.
            ///
            /// It is inadvisable to directly read from the underlying reader.
            pub fn get_ref(&self) -> &R {
                &self.0.reader
            }

            /// Gets a mutable reference to the underlying reader.
            ///
            /// It is inadvisable to directly read from the underlying reader.
            pub fn get_mut(&mut self) -> &mut R {
                &mut self.0.reader
            }

            /// Returns a reference to the internally buffered data.
            ///
            /// This will not attempt to fill the buffer if it is empty.
            pub fn buffer(&self) -> &[u8] {
                &self.0.buffer[..]
            }

            /// Unwraps this `AsyncBincodeReader`, returning the underlying reader.
            ///
            /// Note that any leftover data in the internal buffer is lost.
            pub fn into_inner(self) -> R {
                self.0.reader
            }
        }

        impl<R, T> ::futures_core::Stream for AsyncBincodeReader<R, T>
        where
            for<'a> T: ::serde::Deserialize<'a>,
            R: $read_trait + Unpin,
        {
            type Item = Result<T, ::bincode::error::DecodeError>;

            fn poll_next(
                mut self: ::std::pin::Pin<&mut Self>,
                cx: &mut ::std::task::Context,
            ) -> ::std::task::Poll<Option<Self::Item>> {
                // `R: Unpin`, so the inner state is `Unpin` as well and can be re-pinned here.
                ::std::pin::Pin::new(&mut self.0).internal_poll_next(cx, $internal_poll_reader)
            }
        }
    };
}

95#[derive(Debug)]
96pub(crate) struct AsyncBincodeReader<R, T> {
97    pub(crate) reader: R,
98    pub(crate) buffer: BytesMut,
99    pub(crate) into: PhantomData<T>,
100}
101
102impl<R, T> Unpin for AsyncBincodeReader<R, T> where R: Unpin {}
103
/// Successful outcome of filling the read buffer up to a requested size.
enum FillResult {
    /// The buffer now holds at least the requested number of bytes.
    Filled,
    /// The reader reported end-of-file while the buffer was still empty.
    Eof,
}

109impl<R: Unpin, T> AsyncBincodeReader<R, T>
110where
111    for<'a> T: Deserialize<'a>,
112{
113    pub(crate) fn internal_poll_next<F>(
114        mut self: Pin<&mut Self>,
115        cx: &mut Context,
116        poll_reader: F,
117    ) -> Poll<Option<Result<T, bincode::error::DecodeError>>>
118    where
119        F: Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<Result<usize, io::Error>> + Copy,
120    {
121        if let FillResult::Eof = ready!(self.as_mut().fill(cx, 5, poll_reader).map_err(|inner| {
122            bincode::error::DecodeError::Io {
123                inner,
124                additional: 4,
125            }
126        }))? {
127            return Poll::Ready(None);
128        }
129
130        let message_size: u32 = NetworkEndian::read_u32(&self.buffer[..4]);
131        let target_buffer_size = message_size as usize;
132
133        // since self.buffer.len() >= 4, we know that we can't get an clean EOF here
134        ready!(self
135            .as_mut()
136            .fill(cx, target_buffer_size + 4, poll_reader)
137            .map_err(|inner| {
138                bincode::error::DecodeError::Io {
139                    inner,
140                    additional: target_buffer_size,
141                }
142            }))?;
143
144        self.buffer.advance(4);
145        let (message, decoded) = bincode::serde::decode_from_slice(
146            &self.buffer[..target_buffer_size],
147            config::standard().with_limit::<{ u32::MAX as usize }>(),
148        )?;
149        if decoded != target_buffer_size {
150            return Poll::Ready(Some(Err(bincode::error::DecodeError::OtherString(
151                format!("only decoded {decoded} out of {target_buffer_size}-length message"),
152            ))));
153        }
154        self.buffer.advance(target_buffer_size);
155        Poll::Ready(Some(Ok(message)))
156    }
157
158    fn fill<F>(
159        mut self: Pin<&mut Self>,
160        cx: &mut Context,
161        target_size: usize,
162        poll_reader: F,
163    ) -> Poll<Result<FillResult, io::Error>>
164    where
165        F: Fn(Pin<&mut R>, &mut Context, &mut [u8]) -> Poll<Result<usize, io::Error>>,
166    {
167        if self.buffer.len() >= target_size {
168            // we already have the bytes we need!
169            return Poll::Ready(Ok(FillResult::Filled));
170        }
171
172        // make sure we can fit all the data we're about to read
173        // and then some, so we don't do a gazillion syscalls
174        if self.buffer.capacity() < target_size {
175            let missing = target_size - self.buffer.capacity();
176            self.buffer.reserve(missing);
177        }
178
179        let had = self.buffer.len();
180        // this is the bit we'll be reading into
181        let mut rest = self.buffer.split_off(had);
182        // this is safe because we're not extending beyond the reserved capacity
183        // and we're never reading unwritten bytes
184        let max = rest.capacity();
185        unsafe { rest.set_len(max) };
186
187        while self.buffer.len() < target_size {
188            match poll_reader(Pin::new(&mut self.reader), cx, &mut rest[..]) {
189                Poll::Ready(result) => {
190                    match result {
191                        Ok(n) => {
192                            if n == 0 {
193                                if self.buffer.is_empty() {
194                                    return Poll::Ready(Ok(FillResult::Eof));
195                                } else {
196                                    return Poll::Ready(Err(io::Error::from(
197                                        io::ErrorKind::BrokenPipe,
198                                    )));
199                                }
200                            }
201
202                            // adopt the new bytes
203                            let read = rest.split_to(n);
204                            self.buffer.unsplit(read);
205                        }
206                        Err(err) => {
207                            // reading failed, put the buffer back
208                            rest.truncate(0);
209                            self.buffer.unsplit(rest);
210                            return Poll::Ready(Err(err));
211                        }
212                    }
213                }
214                Poll::Pending => {
215                    // reading in progress, put the buffer back
216                    rest.truncate(0);
217                    self.buffer.unsplit(rest);
218                    return Poll::Pending;
219                }
220            }
221        }
222
223        Poll::Ready(Ok(FillResult::Filled))
224    }
225}