// hypercore_protocol/reader.rs

1use crate::crypto::DecryptCipher;
2use futures_lite::io::AsyncRead;
3use futures_timer::Delay;
4use std::future::Future;
5use std::io::{Error, ErrorKind, Result};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE};
10use crate::message::{Frame, FrameType};
11use crate::util::stat_uint24_le;
12use std::time::Duration;
13
14const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64);
15const READ_BUF_INITIAL_SIZE: usize = 1024 * 128;
16
17#[derive(Debug)]
18pub(crate) struct ReadState {
19    /// The read buffer.
20    buf: Vec<u8>,
21    /// The start of the not-yet-processed byte range in the read buffer.
22    start: usize,
23    /// The end of the not-yet-processed byte range in the read buffer.
24    end: usize,
25    /// The logical state of the reading (either header or body).
26    step: Step,
27    /// The timeout after which the connection is closed.
28    timeout: Delay,
29    /// Optional decryption cipher.
30    cipher: Option<DecryptCipher>,
31    /// The frame type to be passed to the decoder.
32    frame_type: FrameType,
33}
34
35impl ReadState {
36    pub(crate) fn new() -> ReadState {
37        ReadState {
38            buf: vec![0u8; READ_BUF_INITIAL_SIZE],
39            start: 0,
40            end: 0,
41            step: Step::Header,
42            timeout: Delay::new(TIMEOUT),
43            cipher: None,
44            frame_type: FrameType::Raw,
45        }
46    }
47}
48
/// Decoding phase of the frame reader.
#[derive(Debug)]
enum Step {
    /// Expecting the length prefix of the next message.
    Header,
    /// A single message whose length prefix has already been parsed.
    Body {
        /// Size of the length prefix in bytes.
        header_len: usize,
        /// Size of the message body in bytes.
        body_len: usize,
    },
    /// Several messages buffered back to back, decoded together.
    Batch,
}
59
60impl ReadState {
61    pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) {
62        self.cipher = Some(decrypt_cipher);
63    }
64
65    /// Decrypts a given buf with stored cipher, if present. Used to correct
66    /// the rare mistake that more than two messages came in where the first
67    /// one created the cipher, and the next one should have been decrypted
68    /// but wasn't.
69    pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result<Vec<u8>> {
70        if let Some(cipher) = self.cipher.as_mut() {
71            Ok(cipher.decrypt_buf(buf)?.0)
72        } else {
73            Ok(buf.to_vec())
74        }
75    }
76
77    pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) {
78        self.frame_type = frame_type;
79    }
80
81    pub(crate) fn poll_reader<R>(
82        &mut self,
83        cx: &mut Context<'_>,
84        mut reader: &mut R,
85    ) -> Poll<Result<Frame>>
86    where
87        R: AsyncRead + Unpin,
88    {
89        let mut incomplete = true;
90        loop {
91            if !incomplete {
92                if let Some(result) = self.process() {
93                    return Poll::Ready(result);
94                }
95            } else {
96                incomplete = false;
97            }
98            let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) {
99                Poll::Ready(Ok(n)) if n > 0 => n,
100                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
101                // If the reader is pending, poll the timeout.
102                Poll::Pending | Poll::Ready(Ok(_)) => {
103                    // Return Pending if the timeout is pending, or an error if the
104                    // timeout expired (i.e. returned Poll::Ready).
105                    return Pin::new(&mut self.timeout)
106                        .poll(cx)
107                        .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out")));
108                }
109            };
110
111            let end = self.end + n;
112            let (success, segments) = create_segments(&self.buf[self.start..end])?;
113            if success {
114                if let Some(ref mut cipher) = self.cipher {
115                    let mut dec_end = self.start;
116                    for (index, header_len, body_len) in segments {
117                        let de = cipher.decrypt(
118                            &mut self.buf[self.start + index..end],
119                            header_len,
120                            body_len,
121                        )?;
122                        dec_end = self.start + index + de;
123                    }
124                    self.end = dec_end;
125                } else {
126                    self.end = end;
127                }
128            } else {
129                // Could not segment due to buffer being full, need to cycle the buffer
130                // and possibly resize it too if the message is too big.
131                self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]);
132
133                // Set incomplete flag to skip processing and instead poll more data
134                incomplete = true;
135            }
136            self.timeout.reset(TIMEOUT);
137        }
138    }
139
140    fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) {
141        let (last_index, last_header_len, last_body_len) = last_segment;
142        let total_incoming_length = last_index + last_header_len + last_body_len;
143        if self.buf.len() < total_incoming_length {
144            // The incoming segments will not fit into the buffer, need to resize it
145            self.buf.resize(total_incoming_length, 0u8);
146        }
147        let temp = self.buf[self.start..].to_vec();
148        let len = temp.len();
149        self.buf[..len].copy_from_slice(&temp[..]);
150        self.end = len;
151        self.start = 0;
152    }
153
154    fn process(&mut self) -> Option<Result<Frame>> {
155        loop {
156            match self.step {
157                Step::Header => {
158                    let stat = stat_uint24_le(&self.buf[self.start..self.end]);
159                    if let Some((header_len, body_len)) = stat {
160                        if body_len == 0 {
161                            // This is a keepalive message, just remain in Step::Header
162                            self.start += header_len;
163                            return None;
164                        } else if (self.start + header_len + body_len as usize) < self.end {
165                            // There are more than one message here, create a batch from all of
166                            // then
167                            self.step = Step::Batch;
168                        } else {
169                            let body_len = body_len as usize;
170                            if body_len > MAX_MESSAGE_SIZE as usize {
171                                return Some(Err(Error::new(
172                                    ErrorKind::InvalidData,
173                                    "Message length above max allowed size",
174                                )));
175                            }
176                            self.step = Step::Body {
177                                header_len,
178                                body_len,
179                            };
180                        }
181                    } else {
182                        return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header")));
183                    }
184                }
185
186                Step::Body {
187                    header_len,
188                    body_len,
189                } => {
190                    let message_len = header_len + body_len;
191                    let range = self.start + header_len..self.start + message_len;
192                    let frame = Frame::decode(&self.buf[range], &self.frame_type);
193                    self.start += message_len;
194                    self.step = Step::Header;
195                    return Some(frame);
196                }
197                Step::Batch => {
198                    let frame =
199                        Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type);
200                    self.start = self.end;
201                    self.step = Step::Header;
202                    return Some(frame);
203                }
204            }
205        }
206    }
207}
208
209#[allow(clippy::type_complexity)]
210fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> {
211    let mut index: usize = 0;
212    let len = buf.len();
213    let mut segments: Vec<(usize, usize, usize)> = vec![];
214    while index < len {
215        if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) {
216            let body_len = body_len as usize;
217            segments.push((index, header_len, body_len));
218            if len < index + header_len + body_len {
219                // The segments will not fit, return false to indicate that more needs to be read
220                return Ok((false, segments));
221            }
222            index += header_len + body_len;
223        } else {
224            return Err(Error::new(
225                ErrorKind::InvalidData,
226                "Could not read header while decrypting",
227            ));
228        }
229    }
230    Ok((true, segments))
231}