hypercore_protocol/
reader.rs1use crate::crypto::DecryptCipher;
2use futures_lite::io::AsyncRead;
3use futures_timer::Delay;
4use std::future::Future;
5use std::io::{Error, ErrorKind, Result};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use crate::constants::{DEFAULT_TIMEOUT, MAX_MESSAGE_SIZE};
10use crate::message::{Frame, FrameType};
11use crate::util::stat_uint24_le;
12use std::time::Duration;
13
14const TIMEOUT: Duration = Duration::from_secs(DEFAULT_TIMEOUT as u64);
15const READ_BUF_INITIAL_SIZE: usize = 1024 * 128;
16
17#[derive(Debug)]
18pub(crate) struct ReadState {
19 buf: Vec<u8>,
21 start: usize,
23 end: usize,
25 step: Step,
27 timeout: Delay,
29 cipher: Option<DecryptCipher>,
31 frame_type: FrameType,
33}
34
35impl ReadState {
36 pub(crate) fn new() -> ReadState {
37 ReadState {
38 buf: vec![0u8; READ_BUF_INITIAL_SIZE],
39 start: 0,
40 end: 0,
41 step: Step::Header,
42 timeout: Delay::new(TIMEOUT),
43 cipher: None,
44 frame_type: FrameType::Raw,
45 }
46 }
47}
48
/// Position of the decoder within the buffered data.
#[derive(Debug)]
enum Step {
    /// Expecting a length header at `start`.
    Header,
    /// A single message whose header and body are both buffered at `start`.
    Body {
        /// Length of the message header, in bytes.
        header_len: usize,
        /// Length of the message body, in bytes.
        body_len: usize,
    },
    /// More than one message is buffered; they are decoded together.
    Batch,
}

60impl ReadState {
61 pub(crate) fn upgrade_with_decrypt_cipher(&mut self, decrypt_cipher: DecryptCipher) {
62 self.cipher = Some(decrypt_cipher);
63 }
64
65 pub(crate) fn decrypt_buf(&mut self, buf: &[u8]) -> Result<Vec<u8>> {
70 if let Some(cipher) = self.cipher.as_mut() {
71 Ok(cipher.decrypt_buf(buf)?.0)
72 } else {
73 Ok(buf.to_vec())
74 }
75 }
76
77 pub(crate) fn set_frame_type(&mut self, frame_type: FrameType) {
78 self.frame_type = frame_type;
79 }
80
81 pub(crate) fn poll_reader<R>(
82 &mut self,
83 cx: &mut Context<'_>,
84 mut reader: &mut R,
85 ) -> Poll<Result<Frame>>
86 where
87 R: AsyncRead + Unpin,
88 {
89 let mut incomplete = true;
90 loop {
91 if !incomplete {
92 if let Some(result) = self.process() {
93 return Poll::Ready(result);
94 }
95 } else {
96 incomplete = false;
97 }
98 let n = match Pin::new(&mut reader).poll_read(cx, &mut self.buf[self.end..]) {
99 Poll::Ready(Ok(n)) if n > 0 => n,
100 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
101 Poll::Pending | Poll::Ready(Ok(_)) => {
103 return Pin::new(&mut self.timeout)
106 .poll(cx)
107 .map(|()| Err(Error::new(ErrorKind::TimedOut, "Remote timed out")));
108 }
109 };
110
111 let end = self.end + n;
112 let (success, segments) = create_segments(&self.buf[self.start..end])?;
113 if success {
114 if let Some(ref mut cipher) = self.cipher {
115 let mut dec_end = self.start;
116 for (index, header_len, body_len) in segments {
117 let de = cipher.decrypt(
118 &mut self.buf[self.start + index..end],
119 header_len,
120 body_len,
121 )?;
122 dec_end = self.start + index + de;
123 }
124 self.end = dec_end;
125 } else {
126 self.end = end;
127 }
128 } else {
129 self.cycle_buf_and_resize_if_needed(segments[segments.len() - 1]);
132
133 incomplete = true;
135 }
136 self.timeout.reset(TIMEOUT);
137 }
138 }
139
140 fn cycle_buf_and_resize_if_needed(&mut self, last_segment: (usize, usize, usize)) {
141 let (last_index, last_header_len, last_body_len) = last_segment;
142 let total_incoming_length = last_index + last_header_len + last_body_len;
143 if self.buf.len() < total_incoming_length {
144 self.buf.resize(total_incoming_length, 0u8);
146 }
147 let temp = self.buf[self.start..].to_vec();
148 let len = temp.len();
149 self.buf[..len].copy_from_slice(&temp[..]);
150 self.end = len;
151 self.start = 0;
152 }
153
154 fn process(&mut self) -> Option<Result<Frame>> {
155 loop {
156 match self.step {
157 Step::Header => {
158 let stat = stat_uint24_le(&self.buf[self.start..self.end]);
159 if let Some((header_len, body_len)) = stat {
160 if body_len == 0 {
161 self.start += header_len;
163 return None;
164 } else if (self.start + header_len + body_len as usize) < self.end {
165 self.step = Step::Batch;
168 } else {
169 let body_len = body_len as usize;
170 if body_len > MAX_MESSAGE_SIZE as usize {
171 return Some(Err(Error::new(
172 ErrorKind::InvalidData,
173 "Message length above max allowed size",
174 )));
175 }
176 self.step = Step::Body {
177 header_len,
178 body_len,
179 };
180 }
181 } else {
182 return Some(Err(Error::new(ErrorKind::InvalidData, "Invalid header")));
183 }
184 }
185
186 Step::Body {
187 header_len,
188 body_len,
189 } => {
190 let message_len = header_len + body_len;
191 let range = self.start + header_len..self.start + message_len;
192 let frame = Frame::decode(&self.buf[range], &self.frame_type);
193 self.start += message_len;
194 self.step = Step::Header;
195 return Some(frame);
196 }
197 Step::Batch => {
198 let frame =
199 Frame::decode_multiple(&self.buf[self.start..self.end], &self.frame_type);
200 self.start = self.end;
201 self.step = Step::Header;
202 return Some(frame);
203 }
204 }
205 }
206 }
207}
208
209#[allow(clippy::type_complexity)]
210fn create_segments(buf: &[u8]) -> Result<(bool, Vec<(usize, usize, usize)>)> {
211 let mut index: usize = 0;
212 let len = buf.len();
213 let mut segments: Vec<(usize, usize, usize)> = vec![];
214 while index < len {
215 if let Some((header_len, body_len)) = stat_uint24_le(&buf[index..]) {
216 let body_len = body_len as usize;
217 segments.push((index, header_len, body_len));
218 if len < index + header_len + body_len {
219 return Ok((false, segments));
221 }
222 index += header_len + body_len;
223 } else {
224 return Err(Error::new(
225 ErrorKind::InvalidData,
226 "Could not read header while decrypting",
227 ));
228 }
229 }
230 Ok((true, segments))
231}