async_io_typed/
read.rs

1use crate::{
2    ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
3    U32_MARKER, U64_MARKER, ZST_MARKER,
4};
5use bincode::Options;
6use futures_core::Stream;
7use futures_io::AsyncRead;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use siphasher::sip::SipHasher;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13use std::mem::size_of;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17/// Provides the ability to read `serde` compatible types from any type that implements `futures::io::AsyncRead`.
18#[derive(Debug)]
19pub struct AsyncReadTyped<R, T: Serialize + DeserializeOwned + Unpin> {
20    raw: R,
21    size_limit: u64,
22    state: AsyncReadState,
23    item_buffer: Vec<u8>,
24    checksum_read_state: ChecksumReadState,
25    _phantom: PhantomData<T>,
26}
27
28#[derive(Debug)]
29pub(crate) enum AsyncReadState {
30    ReadingVersion {
31        version_in_progress: [u8; 8],
32        version_in_progress_assigned: usize,
33    },
34    ReadingChecksumEnabled,
35    Idle,
36    ReadingLen {
37        len_read_mode: LenReadMode,
38        len_in_progress: [u8; 8],
39        len_in_progress_assigned: usize,
40    },
41    ReadingItem {
42        len_read: usize,
43    },
44    ReadingChecksum {
45        checksum_in_progress: [u8; 8],
46        checksum_assigned: usize,
47    },
48    Finished,
49}
50
/// How the reader treats the checksum that may follow every item.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ChecksumReadState {
    /// No checksums arrive from the writer, so the reader has nothing to check.
    No,
    /// Checksums arrive from the writer and the reader validates each item against them.
    /// Both sides have checksums enabled.
    Yes,
    /// Checksums arrive from the writer but the reader skips over them without validating.
    /// The writer has checksums enabled, the reader does not.
    Ignore,
}
62
63impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
64    for AsyncReadTyped<R, T>
65{
66    type Item = Result<T, Error>;
67
68    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69        let Self {
70            ref mut raw,
71            ref size_limit,
72            ref mut item_buffer,
73            ref mut state,
74            ref mut checksum_read_state,
75            _phantom,
76        } = &mut *self;
77        Self::poll_next_impl(
78            state,
79            raw,
80            item_buffer,
81            *size_limit,
82            checksum_read_state,
83            cx,
84        )
85    }
86}
87
88impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncReadTyped<R, T> {
89    /// Creates a typed reader, initializing it with the given size limit specified in bytes.
90    ///
91    /// Be careful, large limits might create a vulnerability to a Denial of Service attack.
92    pub fn new_with_limit(raw: R, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
93        Self {
94            raw,
95            size_limit,
96            state: AsyncReadState::ReadingVersion {
97                version_in_progress: [0; 8],
98                version_in_progress_assigned: 0,
99            },
100            item_buffer: Vec::new(),
101            checksum_read_state: checksum_enabled.into(),
102            _phantom: PhantomData,
103        }
104    }
105
106    /// Creates a typed reader, initializing it with a default size limit of 1 MB.
107    pub fn new(raw: R, checksum_enabled: ChecksumEnabled) -> Self {
108        Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
109    }
110
111    /// Returns a reference to the raw I/O primitive that this type is using.
112    pub fn inner(&self) -> &R {
113        &self.raw
114    }
115
116    /// Consumes this `AsyncReadTyped` and returns the raw I/O primitive that was being used.
117    pub fn into_inner(self) -> R {
118        self.raw
119    }
120
121    /// `AsyncReadTyped` keeps a memory buffer for receiving values which is the same size as the largest
122    /// message that's been received. If the message size varies a lot, you might find yourself wasting
123    /// memory space. This function will reduce the memory usage as much as is possible without impeding
124    /// functioning. Overuse of this function may cause excessive memory allocations when the buffer
125    /// needs to grow.
126    pub fn optimize_memory_usage(&mut self) {
127        match self.state {
128            AsyncReadState::ReadingItem { .. } => self.item_buffer.shrink_to_fit(),
129            _ => {
130                self.item_buffer = Vec::new();
131            }
132        }
133    }
134
135    /// Reports the size of the memory buffer used for receiving values. You can shrink this buffer as much as
136    /// possible with [`Self::optimize_memory_usage`].
137    pub fn current_memory_usage(&self) -> usize {
138        self.item_buffer.capacity()
139    }
140
141    /// Returns true if checksums are enabled for this channel. This may become false after receiving the first value.
142    /// If that happens, the writer may have disabled checksums, so there is no checksum for the reader to check.
143    pub fn checksum_enabled(&self) -> bool {
144        self.checksum_read_state == ChecksumReadState::Yes
145    }
146
147    pub(crate) fn poll_next_impl(
148        state: &mut AsyncReadState,
149        mut raw: &mut R,
150        item_buffer: &mut Vec<u8>,
151        size_limit: u64,
152        checksum_read_state: &mut ChecksumReadState,
153        cx: &mut Context,
154    ) -> Poll<Option<Result<T, Error>>> {
155        loop {
156            return match state {
157                AsyncReadState::ReadingVersion {
158                    version_in_progress,
159                    version_in_progress_assigned,
160                } => {
161                    while *version_in_progress_assigned < size_of::<u64>() {
162                        let len = futures_core::ready!(Pin::new(&mut raw).poll_read(
163                            cx,
164                            &mut version_in_progress[(*version_in_progress_assigned)..]
165                        ))?;
166                        *version_in_progress_assigned += len;
167                    }
168                    let version = u64::from_le_bytes(*version_in_progress);
169                    if version != PROTOCOL_VERSION {
170                        *state = AsyncReadState::Finished;
171                        return Poll::Ready(Some(Err(Error::ProtocolVersionMismatch {
172                            our_version: PROTOCOL_VERSION,
173                            their_version: version,
174                        })));
175                    }
176                    *state = AsyncReadState::ReadingChecksumEnabled;
177                    continue;
178                }
179                AsyncReadState::ReadingChecksumEnabled => {
180                    let mut checksum_enabled = [0];
181                    if futures_core::ready!(Pin::new(&mut raw).poll_read(cx, &mut checksum_enabled))?
182                        == 1
183                    {
184                        match checksum_enabled[0] {
185                            CHECKSUM_ENABLED => {
186                                match *checksum_read_state {
187                                    ChecksumReadState::Yes => {
188                                        // Do nothing, we are in agreement that a checksum should be used.
189                                    }
190                                    ChecksumReadState::No => {
191                                        // The peer is going to send checksums and we can't tell them to stop.
192                                        // Ignore the checksums.
193                                        *checksum_read_state = ChecksumReadState::Ignore;
194                                    }
195                                    ChecksumReadState::Ignore => {
196                                        // This should never happen, but if it does we can continue ignoring them I suppose.
197                                    }
198                                }
199                            }
200                            CHECKSUM_DISABLED => {
201                                // We can't use checksums if the peer won't send them, so disable them.
202                                *checksum_read_state = ChecksumReadState::No;
203                            }
204                            _ => {
205                                *state = AsyncReadState::Finished;
206                                return Poll::Ready(Some(Err(Error::ChecksumHandshakeFailed {
207                                    checksum_value: checksum_enabled[0],
208                                })));
209                            }
210                        }
211                        *state = AsyncReadState::Idle;
212                    }
213                    continue;
214                }
215                AsyncReadState::Idle => {
216                    let mut buf = [0];
217                    futures_core::ready!(Pin::new(&mut raw).poll_read(cx, &mut buf))?;
218                    match buf[0] {
219                        U16_MARKER => {
220                            *state = AsyncReadState::ReadingLen {
221                                len_read_mode: LenReadMode::U16,
222                                len_in_progress: Default::default(),
223                                len_in_progress_assigned: 0,
224                            };
225                        }
226                        U32_MARKER => {
227                            *state = AsyncReadState::ReadingLen {
228                                len_read_mode: LenReadMode::U32,
229                                len_in_progress: Default::default(),
230                                len_in_progress_assigned: 0,
231                            };
232                        }
233                        U64_MARKER => {
234                            *state = AsyncReadState::ReadingLen {
235                                len_read_mode: LenReadMode::U64,
236                                len_in_progress: Default::default(),
237                                len_in_progress_assigned: 0,
238                            };
239                        }
240                        ZST_MARKER => {
241                            item_buffer.truncate(0);
242                            *state = AsyncReadState::ReadingItem { len_read: 0 };
243                        }
244                        0 => {
245                            *state = AsyncReadState::Finished;
246                            return Poll::Ready(None);
247                        }
248                        other => {
249                            item_buffer.resize(other as usize, 0);
250                            *state = AsyncReadState::ReadingItem { len_read: 0 };
251                        }
252                    }
253                    continue;
254                }
255                AsyncReadState::ReadingLen {
256                    ref mut len_read_mode,
257                    ref mut len_in_progress,
258                    ref mut len_in_progress_assigned,
259                } => {
260                    let mut buf = [0; 8];
261                    let accumulated = *len_in_progress_assigned;
262                    let slice = match len_read_mode {
263                        LenReadMode::U16 => &mut buf[accumulated..2],
264                        LenReadMode::U32 => &mut buf[accumulated..4],
265                        LenReadMode::U64 => &mut buf[accumulated..8],
266                    };
267                    let len = futures_core::ready!(Pin::new(&mut raw).poll_read(cx, slice))?;
268                    len_in_progress[accumulated..(accumulated + len)]
269                        .copy_from_slice(&slice[..len]);
270                    *len_in_progress_assigned += len;
271                    if len == slice.len() {
272                        let new_len = match len_read_mode {
273                            LenReadMode::U16 => u16::from_le_bytes(
274                                (&len_in_progress[0..2]).try_into().expect("infallible"),
275                            ) as u64,
276                            LenReadMode::U32 => u32::from_le_bytes(
277                                (&len_in_progress[0..4]).try_into().expect("infallible"),
278                            ) as u64,
279                            LenReadMode::U64 => u64::from_le_bytes(*len_in_progress),
280                        };
281                        if new_len > size_limit {
282                            *state = AsyncReadState::Finished;
283                            return Poll::Ready(Some(Err(Error::ReceivedMessageTooLarge)));
284                        }
285                        item_buffer.resize(new_len as usize, 0);
286                        *state = AsyncReadState::ReadingItem { len_read: 0 };
287                    }
288                    continue;
289                }
290                AsyncReadState::ReadingItem { ref mut len_read } => {
291                    while *len_read < item_buffer.len() {
292                        let len = futures_core::ready!(
293                            Pin::new(&mut raw).poll_read(cx, &mut item_buffer[*len_read..])
294                        )?;
295                        *len_read += len;
296                    }
297                    if [ChecksumReadState::Yes, ChecksumReadState::Ignore]
298                        .contains(checksum_read_state)
299                    {
300                        *state = AsyncReadState::ReadingChecksum {
301                            checksum_in_progress: [0; 8],
302                            checksum_assigned: 0,
303                        };
304                        continue;
305                    } else {
306                        let ret = Poll::Ready(Some(
307                            crate::bincode_options(size_limit)
308                                .deserialize(item_buffer)
309                                .map_err(Error::Bincode),
310                        ));
311                        *state = AsyncReadState::Idle;
312                        ret
313                    }
314                }
315                AsyncReadState::ReadingChecksum {
316                    checksum_in_progress,
317                    checksum_assigned,
318                } => {
319                    while *checksum_assigned < size_of::<u64>() {
320                        let len = futures_core::ready!(Pin::new(&mut raw)
321                            .poll_read(cx, &mut checksum_in_progress[(*checksum_assigned)..]))?;
322                        *checksum_assigned += len;
323                    }
324                    let ret = (*checksum_read_state == ChecksumReadState::Yes)
325                        .then(|| {
326                            let sent_checksum = u64::from_le_bytes(*checksum_in_progress);
327                            let mut hasher = SipHasher::new();
328                            hasher.write(item_buffer);
329                            let computed_checksum = hasher.finish();
330                            (sent_checksum != computed_checksum).then_some(Err(
331                                Error::ChecksumMismatch {
332                                    sent_checksum,
333                                    computed_checksum,
334                                },
335                            ))
336                        })
337                        .flatten()
338                        .unwrap_or_else(|| {
339                            crate::bincode_options(size_limit)
340                                .deserialize(item_buffer)
341                                .map_err(Error::Bincode)
342                        });
343                    *state = AsyncReadState::Idle;
344                    Poll::Ready(Some(ret))
345                }
346                AsyncReadState::Finished => Poll::Ready(None),
347            };
348        }
349    }
350}
351
/// Width of the explicit length prefix that precedes an item on the wire.
#[derive(Debug)]
pub(crate) enum LenReadMode {
    /// The length is sent as a little-endian `u16`.
    U16,
    /// The length is sent as a little-endian `u32`.
    U32,
    /// The length is sent as a little-endian `u64`.
    U64,
}