1use crate::{
2 ChecksumEnabled, Error, CHECKSUM_DISABLED, CHECKSUM_ENABLED, PROTOCOL_VERSION, U16_MARKER,
3 U32_MARKER, U64_MARKER, ZST_MARKER,
4};
5use bincode::Options;
6use futures_core::Stream;
7use futures_io::AsyncRead;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use siphasher::sip::SipHasher;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13use std::mem::size_of;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17#[derive(Debug)]
19pub struct AsyncReadTyped<R, T: Serialize + DeserializeOwned + Unpin> {
20 raw: R,
21 size_limit: u64,
22 state: AsyncReadState,
23 item_buffer: Vec<u8>,
24 checksum_read_state: ChecksumReadState,
25 _phantom: PhantomData<T>,
26}
27
/// States of the incremental read state machine driven by `AsyncReadTyped::poll_next_impl`.
///
/// Variants that receive a multi-byte field record how many bytes have arrived so far, so a
/// read interrupted by `Poll::Pending` resumes where it left off on the next poll.
#[derive(Debug)]
pub(crate) enum AsyncReadState {
    /// Receiving the peer's 8-byte little-endian protocol version.
    ReadingVersion {
        /// Bytes of the version received so far.
        version_in_progress: [u8; 8],
        /// Number of valid bytes at the front of `version_in_progress`.
        version_in_progress_assigned: usize,
    },
    /// Waiting for the single handshake byte saying whether the peer sends checksums.
    ReadingChecksumEnabled,
    /// Between items; the next byte read is an item marker.
    Idle,
    /// Receiving an item's little-endian length prefix.
    ReadingLen {
        /// Width of the prefix (2, 4 or 8 bytes), chosen by the preceding marker byte.
        len_read_mode: LenReadMode,
        /// Bytes of the prefix received so far; only the first `width` bytes are used.
        len_in_progress: [u8; 8],
        /// Number of valid bytes at the front of `len_in_progress`.
        len_in_progress_assigned: usize,
    },
    /// Filling the reader's item buffer with the item's payload.
    ReadingItem {
        /// Number of payload bytes received so far.
        len_read: usize,
    },
    /// Receiving the 8-byte little-endian checksum that follows an item's payload.
    ReadingChecksum {
        /// Bytes of the checksum received so far.
        checksum_in_progress: [u8; 8],
        /// Number of valid bytes at the front of `checksum_in_progress`.
        checksum_assigned: usize,
    },
    /// Terminal state: the stream ended or hit a fatal error; further polls yield `None`.
    Finished,
}
50
/// How the checksum that may trail each item is handled.
///
/// Starts out derived from the local `ChecksumEnabled` setting (via a `From`/`Into`
/// conversion defined elsewhere) and is adjusted once the peer's checksum handshake byte
/// has been read.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(crate) enum ChecksumReadState {
    /// No checksum follows an item; the payload is deserialized directly. Set whenever the
    /// peer announces `CHECKSUM_DISABLED`.
    No,
    /// A checksum follows each item and is verified against a SipHash of the payload.
    Yes,
    /// The peer sends checksums but local verification is off: the 8 checksum bytes are
    /// read from the stream and discarded.
    Ignore,
}
62
63impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> Stream
64 for AsyncReadTyped<R, T>
65{
66 type Item = Result<T, Error>;
67
68 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
69 let Self {
70 ref mut raw,
71 ref size_limit,
72 ref mut item_buffer,
73 ref mut state,
74 ref mut checksum_read_state,
75 _phantom,
76 } = &mut *self;
77 Self::poll_next_impl(
78 state,
79 raw,
80 item_buffer,
81 *size_limit,
82 checksum_read_state,
83 cx,
84 )
85 }
86}
87
88impl<R: AsyncRead + Unpin, T: Serialize + DeserializeOwned + Unpin> AsyncReadTyped<R, T> {
89 pub fn new_with_limit(raw: R, size_limit: u64, checksum_enabled: ChecksumEnabled) -> Self {
93 Self {
94 raw,
95 size_limit,
96 state: AsyncReadState::ReadingVersion {
97 version_in_progress: [0; 8],
98 version_in_progress_assigned: 0,
99 },
100 item_buffer: Vec::new(),
101 checksum_read_state: checksum_enabled.into(),
102 _phantom: PhantomData,
103 }
104 }
105
106 pub fn new(raw: R, checksum_enabled: ChecksumEnabled) -> Self {
108 Self::new_with_limit(raw, 1024u64.pow(2), checksum_enabled)
109 }
110
111 pub fn inner(&self) -> &R {
113 &self.raw
114 }
115
116 pub fn into_inner(self) -> R {
118 self.raw
119 }
120
121 pub fn optimize_memory_usage(&mut self) {
127 match self.state {
128 AsyncReadState::ReadingItem { .. } => self.item_buffer.shrink_to_fit(),
129 _ => {
130 self.item_buffer = Vec::new();
131 }
132 }
133 }
134
135 pub fn current_memory_usage(&self) -> usize {
138 self.item_buffer.capacity()
139 }
140
141 pub fn checksum_enabled(&self) -> bool {
144 self.checksum_read_state == ChecksumReadState::Yes
145 }
146
147 pub(crate) fn poll_next_impl(
148 state: &mut AsyncReadState,
149 mut raw: &mut R,
150 item_buffer: &mut Vec<u8>,
151 size_limit: u64,
152 checksum_read_state: &mut ChecksumReadState,
153 cx: &mut Context,
154 ) -> Poll<Option<Result<T, Error>>> {
155 loop {
156 return match state {
157 AsyncReadState::ReadingVersion {
158 version_in_progress,
159 version_in_progress_assigned,
160 } => {
161 while *version_in_progress_assigned < size_of::<u64>() {
162 let len = futures_core::ready!(Pin::new(&mut raw).poll_read(
163 cx,
164 &mut version_in_progress[(*version_in_progress_assigned)..]
165 ))?;
166 *version_in_progress_assigned += len;
167 }
168 let version = u64::from_le_bytes(*version_in_progress);
169 if version != PROTOCOL_VERSION {
170 *state = AsyncReadState::Finished;
171 return Poll::Ready(Some(Err(Error::ProtocolVersionMismatch {
172 our_version: PROTOCOL_VERSION,
173 their_version: version,
174 })));
175 }
176 *state = AsyncReadState::ReadingChecksumEnabled;
177 continue;
178 }
179 AsyncReadState::ReadingChecksumEnabled => {
180 let mut checksum_enabled = [0];
181 if futures_core::ready!(Pin::new(&mut raw).poll_read(cx, &mut checksum_enabled))?
182 == 1
183 {
184 match checksum_enabled[0] {
185 CHECKSUM_ENABLED => {
186 match *checksum_read_state {
187 ChecksumReadState::Yes => {
188 }
190 ChecksumReadState::No => {
191 *checksum_read_state = ChecksumReadState::Ignore;
194 }
195 ChecksumReadState::Ignore => {
196 }
198 }
199 }
200 CHECKSUM_DISABLED => {
201 *checksum_read_state = ChecksumReadState::No;
203 }
204 _ => {
205 *state = AsyncReadState::Finished;
206 return Poll::Ready(Some(Err(Error::ChecksumHandshakeFailed {
207 checksum_value: checksum_enabled[0],
208 })));
209 }
210 }
211 *state = AsyncReadState::Idle;
212 }
213 continue;
214 }
215 AsyncReadState::Idle => {
216 let mut buf = [0];
217 futures_core::ready!(Pin::new(&mut raw).poll_read(cx, &mut buf))?;
218 match buf[0] {
219 U16_MARKER => {
220 *state = AsyncReadState::ReadingLen {
221 len_read_mode: LenReadMode::U16,
222 len_in_progress: Default::default(),
223 len_in_progress_assigned: 0,
224 };
225 }
226 U32_MARKER => {
227 *state = AsyncReadState::ReadingLen {
228 len_read_mode: LenReadMode::U32,
229 len_in_progress: Default::default(),
230 len_in_progress_assigned: 0,
231 };
232 }
233 U64_MARKER => {
234 *state = AsyncReadState::ReadingLen {
235 len_read_mode: LenReadMode::U64,
236 len_in_progress: Default::default(),
237 len_in_progress_assigned: 0,
238 };
239 }
240 ZST_MARKER => {
241 item_buffer.truncate(0);
242 *state = AsyncReadState::ReadingItem { len_read: 0 };
243 }
244 0 => {
245 *state = AsyncReadState::Finished;
246 return Poll::Ready(None);
247 }
248 other => {
249 item_buffer.resize(other as usize, 0);
250 *state = AsyncReadState::ReadingItem { len_read: 0 };
251 }
252 }
253 continue;
254 }
255 AsyncReadState::ReadingLen {
256 ref mut len_read_mode,
257 ref mut len_in_progress,
258 ref mut len_in_progress_assigned,
259 } => {
260 let mut buf = [0; 8];
261 let accumulated = *len_in_progress_assigned;
262 let slice = match len_read_mode {
263 LenReadMode::U16 => &mut buf[accumulated..2],
264 LenReadMode::U32 => &mut buf[accumulated..4],
265 LenReadMode::U64 => &mut buf[accumulated..8],
266 };
267 let len = futures_core::ready!(Pin::new(&mut raw).poll_read(cx, slice))?;
268 len_in_progress[accumulated..(accumulated + len)]
269 .copy_from_slice(&slice[..len]);
270 *len_in_progress_assigned += len;
271 if len == slice.len() {
272 let new_len = match len_read_mode {
273 LenReadMode::U16 => u16::from_le_bytes(
274 (&len_in_progress[0..2]).try_into().expect("infallible"),
275 ) as u64,
276 LenReadMode::U32 => u32::from_le_bytes(
277 (&len_in_progress[0..4]).try_into().expect("infallible"),
278 ) as u64,
279 LenReadMode::U64 => u64::from_le_bytes(*len_in_progress),
280 };
281 if new_len > size_limit {
282 *state = AsyncReadState::Finished;
283 return Poll::Ready(Some(Err(Error::ReceivedMessageTooLarge)));
284 }
285 item_buffer.resize(new_len as usize, 0);
286 *state = AsyncReadState::ReadingItem { len_read: 0 };
287 }
288 continue;
289 }
290 AsyncReadState::ReadingItem { ref mut len_read } => {
291 while *len_read < item_buffer.len() {
292 let len = futures_core::ready!(
293 Pin::new(&mut raw).poll_read(cx, &mut item_buffer[*len_read..])
294 )?;
295 *len_read += len;
296 }
297 if [ChecksumReadState::Yes, ChecksumReadState::Ignore]
298 .contains(checksum_read_state)
299 {
300 *state = AsyncReadState::ReadingChecksum {
301 checksum_in_progress: [0; 8],
302 checksum_assigned: 0,
303 };
304 continue;
305 } else {
306 let ret = Poll::Ready(Some(
307 crate::bincode_options(size_limit)
308 .deserialize(item_buffer)
309 .map_err(Error::Bincode),
310 ));
311 *state = AsyncReadState::Idle;
312 ret
313 }
314 }
315 AsyncReadState::ReadingChecksum {
316 checksum_in_progress,
317 checksum_assigned,
318 } => {
319 while *checksum_assigned < size_of::<u64>() {
320 let len = futures_core::ready!(Pin::new(&mut raw)
321 .poll_read(cx, &mut checksum_in_progress[(*checksum_assigned)..]))?;
322 *checksum_assigned += len;
323 }
324 let ret = (*checksum_read_state == ChecksumReadState::Yes)
325 .then(|| {
326 let sent_checksum = u64::from_le_bytes(*checksum_in_progress);
327 let mut hasher = SipHasher::new();
328 hasher.write(item_buffer);
329 let computed_checksum = hasher.finish();
330 (sent_checksum != computed_checksum).then_some(Err(
331 Error::ChecksumMismatch {
332 sent_checksum,
333 computed_checksum,
334 },
335 ))
336 })
337 .flatten()
338 .unwrap_or_else(|| {
339 crate::bincode_options(size_limit)
340 .deserialize(item_buffer)
341 .map_err(Error::Bincode)
342 });
343 *state = AsyncReadState::Idle;
344 Poll::Ready(Some(ret))
345 }
346 AsyncReadState::Finished => Poll::Ready(None),
347 };
348 }
349 }
350}
351
/// Width of an item's little-endian length prefix, selected by the marker byte before it.
#[derive(Debug)]
pub(crate) enum LenReadMode {
    /// 2-byte length, introduced by `U16_MARKER`.
    U16,
    /// 4-byte length, introduced by `U32_MARKER`.
    U32,
    /// 8-byte length, introduced by `U64_MARKER`.
    U64,
}